Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
This commit is contained in:
commit
c0121d69e7
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
|
@ -344,3 +344,15 @@ jobs:
|
|||
env:
|
||||
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
|
||||
working-directory: complement
|
||||
|
||||
# a job which marks all the other jobs as complete, thus allowing PRs to be merged.
|
||||
tests-done:
|
||||
needs:
|
||||
- trial
|
||||
- trial-olddeps
|
||||
- sytest
|
||||
- portdb
|
||||
- complement
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: "true"
|
1
changelog.d/10332.feature
Normal file
1
changelog.d/10332.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.
|
1
changelog.d/10348.misc
Normal file
1
changelog.d/10348.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Run `pyupgrade` on the codebase.
|
1
changelog.d/10382.misc
Normal file
1
changelog.d/10382.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert internal type variable syntax to reflect wider ecosystem use.
|
1
changelog.d/10386.removal
Normal file
1
changelog.d/10386.removal
Normal file
|
@ -0,0 +1 @@
|
|||
The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information.
|
1
changelog.d/10404.bugfix
Normal file
1
changelog.d/10404.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Responses from `/make_{join,leave,knock}` no longer include signatures, which will turn out to be invalid after events are returned to `/send_{join,leave,knock}`.
|
1
changelog.d/10414.bugfix
Normal file
1
changelog.d/10414.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a number of logged errors caused by remote servers being down.
|
1
changelog.d/10418.misc
Normal file
1
changelog.d/10418.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert internal type variable syntax to reflect wider ecosystem use.
|
1
changelog.d/10421.misc
Normal file
1
changelog.d/10421.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove unused `events_by_room` code (tech debt).
|
1
changelog.d/10427.feature
Normal file
1
changelog.d/10427.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.
|
1
changelog.d/10430.misc
Normal file
1
changelog.d/10430.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add a github actions job recording success of other jobs.
|
1
changelog.d/9884.feature
Normal file
1
changelog.d/9884.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add a module type for the account validity feature.
|
109
docs/modules.md
109
docs/modules.md
|
@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following
|
|||
API method:
|
||||
|
||||
```python
|
||||
def ModuleApi.register_web_resource(path: str, resource: IResource)
|
||||
def ModuleApi.register_web_resource(path: str, resource: IResource) -> None
|
||||
```
|
||||
|
||||
The path is the full absolute path to register the resource at. For example, if you
|
||||
|
@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c
|
|||
and is under no obligation to implement all callbacks from the categories it registers
|
||||
callbacks for.
|
||||
|
||||
Modules can register callbacks using one of the module API's `register_[...]_callbacks`
|
||||
methods. The callback functions are passed to these methods as keyword arguments, with
|
||||
the callback name as the argument name and the function as its value. This is demonstrated
|
||||
in the example below. A `register_[...]_callbacks` method exists for each module type
|
||||
documented in this section.
|
||||
|
||||
#### Spam checker callbacks
|
||||
|
||||
To register one of the callbacks described in this section, a module needs to use the
|
||||
module API's `register_spam_checker_callbacks` method. The callback functions are passed
|
||||
to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the
|
||||
argument name and the function as its value. This is demonstrated in the example below.
|
||||
Spam checker callbacks allow module developers to implement spam mitigation actions for
|
||||
Synapse instances. Spam checker callbacks can be registered using the module API's
|
||||
`register_spam_checker_callbacks` method.
|
||||
|
||||
The available spam checker callbacks are:
|
||||
|
||||
|
@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
|
|||
|
||||
Called when processing an invitation. The module must return a `bool` indicating whether
|
||||
the inviter can invite the invitee to the given room. Both inviter and invitee are
|
||||
represented by their Matrix user ID (i.e. `@alice:example.com`).
|
||||
represented by their Matrix user ID (e.g. `@alice:example.com`).
|
||||
|
||||
```python
|
||||
async def user_may_create_room(user: str) -> bool
|
||||
|
@ -181,13 +186,103 @@ The arguments passed to this callback are:
|
|||
```python
|
||||
async def check_media_file_for_spam(
|
||||
file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper",
|
||||
file_info: "synapse.rest.media.v1._base.FileInfo"
|
||||
file_info: "synapse.rest.media.v1._base.FileInfo",
|
||||
) -> bool
|
||||
```
|
||||
|
||||
Called when storing a local or remote file. The module must return a boolean indicating
|
||||
whether the given file can be stored in the homeserver's media store.
|
||||
|
||||
#### Account validity callbacks
|
||||
|
||||
Account validity callbacks allow module developers to add extra steps to verify the
|
||||
validity on an account, i.e. see if a user can be granted access to their account on the
|
||||
Synapse instance. Account validity callbacks can be registered using the module API's
|
||||
`register_account_validity_callbacks` method.
|
||||
|
||||
The available account validity callbacks are:
|
||||
|
||||
```python
|
||||
async def is_user_expired(user: str) -> Optional[bool]
|
||||
```
|
||||
|
||||
Called when processing any authenticated request (except for logout requests). The module
|
||||
can return a `bool` to indicate whether the user has expired and should be locked out of
|
||||
their account, or `None` if the module wasn't able to figure it out. The user is
|
||||
represented by their Matrix user ID (e.g. `@alice:example.com`).
|
||||
|
||||
If the module returns `True`, the current request will be denied with the error code
|
||||
`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't
|
||||
invalidate the user's access token.
|
||||
|
||||
```python
|
||||
async def on_user_registration(user: str) -> None
|
||||
```
|
||||
|
||||
Called after successfully registering a user, in case the module needs to perform extra
|
||||
operations to keep track of them. (e.g. add them to a database table). The user is
|
||||
represented by their Matrix user ID.
|
||||
|
||||
#### Third party rules callbacks
|
||||
|
||||
Third party rules callbacks allow module developers to add extra checks to verify the
|
||||
validity of incoming events. Third party event rules callbacks can be registered using
|
||||
the module API's `register_third_party_rules_callbacks` method.
|
||||
|
||||
The available third party rules callbacks are:
|
||||
|
||||
```python
|
||||
async def check_event_allowed(
|
||||
event: "synapse.events.EventBase",
|
||||
state_events: "synapse.types.StateMap",
|
||||
) -> Tuple[bool, Optional[dict]]
|
||||
```
|
||||
|
||||
**<span style="color:red">
|
||||
This callback is very experimental and can and will break without notice. Module developers
|
||||
are encouraged to implement `check_event_for_spam` from the spam checker category instead.
|
||||
</span>**
|
||||
|
||||
Called when processing any incoming event, with the event and a `StateMap`
|
||||
representing the current state of the room the event is being sent into. A `StateMap` is
|
||||
a dictionary that maps tuples containing an event type and a state key to the
|
||||
corresponding state event. For example retrieving the room's `m.room.create` event from
|
||||
the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
|
||||
The module must return a boolean indicating whether the event can be allowed.
|
||||
|
||||
Note that this callback function processes incoming events coming via federation
|
||||
traffic (on top of client traffic). This means denying an event might cause the local
|
||||
copy of the room's history to diverge from that of remote servers. This may cause
|
||||
federation issues in the room. It is strongly recommended to only deny events using this
|
||||
callback function if the sender is a local user, or in a private federation in which all
|
||||
servers are using the same module, with the same configuration.
|
||||
|
||||
If the boolean returned by the module is `True`, it may also tell Synapse to replace the
|
||||
event with new data by returning the new event's data as a dictionary. In order to do
|
||||
that, it is recommended the module calls `event.get_dict()` to get the current event as a
|
||||
dictionary, and modify the returned dictionary accordingly.
|
||||
|
||||
Note that replacing the event only works for events sent by local users, not for events
|
||||
received over federation.
|
||||
|
||||
```python
|
||||
async def on_create_room(
|
||||
requester: "synapse.types.Requester",
|
||||
request_content: dict,
|
||||
is_requester_admin: bool,
|
||||
) -> None
|
||||
```
|
||||
|
||||
Called when processing a room creation request, with the `Requester` object for the user
|
||||
performing the request, a dictionary representing the room creation request's JSON body
|
||||
(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom)
|
||||
for a list of possible parameters), and a boolean indicating whether the user performing
|
||||
the request is a server admin.
|
||||
|
||||
Modules can modify the `request_content` (by e.g. adding events to its `initial_state`),
|
||||
or deny the room's creation by raising a `module_api.errors.SynapseError`.
|
||||
|
||||
|
||||
### Porting an existing module that uses the old interface
|
||||
|
||||
In order to port a module that uses Synapse's old module interface, its author needs to:
|
||||
|
|
|
@ -1310,91 +1310,6 @@ account_threepid_delegates:
|
|||
#auto_join_rooms_for_guests: false
|
||||
|
||||
|
||||
## Account Validity ##
|
||||
|
||||
# Optional account validity configuration. This allows for accounts to be denied
|
||||
# any request after a given period.
|
||||
#
|
||||
# Once this feature is enabled, Synapse will look for registered users without an
|
||||
# expiration date at startup and will add one to every account it found using the
|
||||
# current settings at that time.
|
||||
# This means that, if a validity period is set, and Synapse is restarted (it will
|
||||
# then derive an expiration date from the current validity period), and some time
|
||||
# after that the validity period changes and Synapse is restarted, the users'
|
||||
# expiration dates won't be updated unless their account is manually renewed. This
|
||||
# date will be randomly selected within a range [now + period - d ; now + period],
|
||||
# where d is equal to 10% of the validity period.
|
||||
#
|
||||
account_validity:
|
||||
# The account validity feature is disabled by default. Uncomment the
|
||||
# following line to enable it.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# The period after which an account is valid after its registration. When
|
||||
# renewing the account, its validity period will be extended by this amount
|
||||
# of time. This parameter is required when using the account validity
|
||||
# feature.
|
||||
#
|
||||
#period: 6w
|
||||
|
||||
# The amount of time before an account's expiry date at which Synapse will
|
||||
# send an email to the account's email address with a renewal link. By
|
||||
# default, no such emails are sent.
|
||||
#
|
||||
# If you enable this setting, you will also need to fill out the 'email' and
|
||||
# 'public_baseurl' configuration sections.
|
||||
#
|
||||
#renew_at: 1w
|
||||
|
||||
# The subject of the email sent out with the renewal link. '%(app)s' can be
|
||||
# used as a placeholder for the 'app_name' parameter from the 'email'
|
||||
# section.
|
||||
#
|
||||
# Note that the placeholder must be written '%(app)s', including the
|
||||
# trailing 's'.
|
||||
#
|
||||
# If this is not set, a default value is used.
|
||||
#
|
||||
#renew_email_subject: "Renew your %(app)s account"
|
||||
|
||||
# Directory in which Synapse will try to find templates for the HTML files to
|
||||
# serve to the user when trying to renew an account. If not set, default
|
||||
# templates from within the Synapse package will be used.
|
||||
#
|
||||
# The currently available templates are:
|
||||
#
|
||||
# * account_renewed.html: Displayed to the user after they have successfully
|
||||
# renewed their account.
|
||||
#
|
||||
# * account_previously_renewed.html: Displayed to the user if they attempt to
|
||||
# renew their account with a token that is valid, but that has already
|
||||
# been used. In this case the account is not renewed again.
|
||||
#
|
||||
# * invalid_token.html: Displayed to the user when they try to renew an account
|
||||
# with an unknown or invalid renewal token.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
|
||||
# default template contents.
|
||||
#
|
||||
# The file name of some of these templates can be configured below for legacy
|
||||
# reasons.
|
||||
#
|
||||
#template_dir: "res/templates"
|
||||
|
||||
# A custom file name for the 'account_renewed.html' template.
|
||||
#
|
||||
# If not set, the file is assumed to be named "account_renewed.html".
|
||||
#
|
||||
#account_renewed_html_path: "account_renewed.html"
|
||||
|
||||
# A custom file name for the 'invalid_token.html' template.
|
||||
#
|
||||
# If not set, the file is assumed to be named "invalid_token.html".
|
||||
#
|
||||
#invalid_token_html_path: "invalid_token.html"
|
||||
|
||||
|
||||
## Metrics ###
|
||||
|
||||
# Enable collection and rendering of performance metrics
|
||||
|
@ -2739,19 +2654,6 @@ stats:
|
|||
# action: allow
|
||||
|
||||
|
||||
# Server admins can define a Python module that implements extra rules for
|
||||
# allowing or denying incoming events. In order to work, this module needs to
|
||||
# override the methods defined in synapse/events/third_party_rules.py.
|
||||
#
|
||||
# This feature is designed to be used in closed federations only, where each
|
||||
# participating server enforces the same rules.
|
||||
#
|
||||
#third_party_event_rules:
|
||||
# module: "my_custom_project.SuperRulesSet"
|
||||
# config:
|
||||
# example_option: 'things'
|
||||
|
||||
|
||||
## Opentracing ##
|
||||
|
||||
# These settings enable opentracing, which implements distributed tracing.
|
||||
|
|
|
@ -86,6 +86,19 @@ process, for example:
|
|||
```
|
||||
|
||||
|
||||
# Upgrading to v1.39.0
|
||||
|
||||
## Deprecation of the current third-party rules module interface
|
||||
|
||||
The current third-party rules module interface is deprecated in favour of the new generic
|
||||
modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer
|
||||
to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface)
|
||||
to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules)
|
||||
to update their configuration once the modules they are using have been updated.
|
||||
|
||||
We plan to remove support for the current third-party rules interface in September 2021.
|
||||
|
||||
|
||||
# Upgrading to v1.38.0
|
||||
|
||||
## Re-indexing of `events` table on Postgres databases
|
||||
|
|
|
@ -62,6 +62,7 @@ class Auth:
|
|||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
|
||||
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
|
||||
10000, "token_cache"
|
||||
|
@ -69,9 +70,6 @@ class Auth:
|
|||
|
||||
self._auth_blocking = AuthBlocking(self.hs)
|
||||
|
||||
self._account_validity_enabled = (
|
||||
hs.config.account_validity.account_validity_enabled
|
||||
)
|
||||
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||
|
@ -187,12 +185,17 @@ class Auth:
|
|||
shadow_banned = user_info.shadow_banned
|
||||
|
||||
# Deny the request if the user account has expired.
|
||||
if self._account_validity_enabled and not allow_expired:
|
||||
if await self.store.is_account_expired(
|
||||
user_info.user_id, self.clock.time_msec()
|
||||
if not allow_expired:
|
||||
if await self._account_validity_handler.is_user_expired(
|
||||
user_info.user_id
|
||||
):
|
||||
# Raise the error if either an account validity module has determined
|
||||
# the account has expired, or the legacy account validity
|
||||
# implementation is enabled and determined the account has expired
|
||||
raise AuthError(
|
||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||
403,
|
||||
"User account has expired",
|
||||
errcode=Codes.EXPIRED_ACCOUNT,
|
||||
)
|
||||
|
||||
device_id = user_info.device_id
|
||||
|
|
|
@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
|
|||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.events.spamcheck import load_legacy_spam_checkers
|
||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.metrics.jemalloc import setup_jemalloc_stats
|
||||
|
@ -368,6 +369,7 @@ async def start(hs: "HomeServer"):
|
|||
module(config=config, api=module_api)
|
||||
|
||||
load_legacy_spam_checkers(hs)
|
||||
load_legacy_third_party_event_rules(hs)
|
||||
|
||||
# If we've configured an expiry time for caches, start the background job now.
|
||||
setup_expire_lru_cache_entries(hs)
|
||||
|
|
|
@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer):
|
|||
elif listener.type == "metrics":
|
||||
if not self.config.enable_metrics:
|
||||
logger.warning(
|
||||
(
|
||||
"Metrics listener configured, but "
|
||||
"enable_metrics is not True!"
|
||||
)
|
||||
"Metrics listener configured, but "
|
||||
"enable_metrics is not True!"
|
||||
)
|
||||
else:
|
||||
_base.listen_metrics(listener.bind_addresses, listener.port)
|
||||
|
|
|
@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer):
|
|||
elif listener.type == "metrics":
|
||||
if not self.config.enable_metrics:
|
||||
logger.warning(
|
||||
(
|
||||
"Metrics listener configured, but "
|
||||
"enable_metrics is not True!"
|
||||
)
|
||||
"Metrics listener configured, but "
|
||||
"enable_metrics is not True!"
|
||||
)
|
||||
else:
|
||||
_base.listen_metrics(listener.bind_addresses, listener.port)
|
||||
|
|
|
@ -71,6 +71,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
|
|||
# General statistics
|
||||
#
|
||||
|
||||
store = hs.get_datastore()
|
||||
|
||||
stats["homeserver"] = hs.config.server_name
|
||||
stats["server_context"] = hs.config.server_context
|
||||
stats["timestamp"] = now
|
||||
|
@ -79,34 +81,38 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
|
|||
stats["python_version"] = "{}.{}.{}".format(
|
||||
version.major, version.minor, version.micro
|
||||
)
|
||||
stats["total_users"] = await hs.get_datastore().count_all_users()
|
||||
stats["total_users"] = await store.count_all_users()
|
||||
|
||||
total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
|
||||
total_nonbridged_users = await store.count_nonbridged_users()
|
||||
stats["total_nonbridged_users"] = total_nonbridged_users
|
||||
|
||||
daily_user_type_results = await hs.get_datastore().count_daily_user_type()
|
||||
daily_user_type_results = await store.count_daily_user_type()
|
||||
for name, count in daily_user_type_results.items():
|
||||
stats["daily_user_type_" + name] = count
|
||||
|
||||
room_count = await hs.get_datastore().get_room_count()
|
||||
room_count = await store.get_room_count()
|
||||
stats["total_room_count"] = room_count
|
||||
|
||||
stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
|
||||
stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
|
||||
daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
|
||||
stats["daily_active_users"] = await store.count_daily_users()
|
||||
stats["monthly_active_users"] = await store.count_monthly_users()
|
||||
daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
|
||||
stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
|
||||
stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
|
||||
daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
|
||||
stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages()
|
||||
daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages()
|
||||
stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
|
||||
stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
|
||||
stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
|
||||
daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
|
||||
stats["daily_active_rooms"] = await store.count_daily_active_rooms()
|
||||
stats["daily_messages"] = await store.count_daily_messages()
|
||||
daily_sent_messages = await store.count_daily_sent_messages()
|
||||
stats["daily_sent_messages"] = daily_sent_messages
|
||||
|
||||
r30_results = await hs.get_datastore().count_r30_users()
|
||||
r30_results = await store.count_r30_users()
|
||||
for name, count in r30_results.items():
|
||||
stats["r30_users_" + name] = count
|
||||
|
||||
r30v2_results = await store.count_r30_users()
|
||||
for name, count in r30v2_results.items():
|
||||
stats["r30v2_users_" + name] = count
|
||||
|
||||
stats["cache_factor"] = hs.config.caches.global_factor
|
||||
stats["event_cache_size"] = hs.config.caches.event_cache_size
|
||||
|
||||
|
@ -115,8 +121,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
|
|||
#
|
||||
|
||||
# This only reports info about the *main* database.
|
||||
stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
|
||||
stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
|
||||
stats["database_engine"] = store.db_pool.engine.module.__name__
|
||||
stats["database_server_version"] = store.db_pool.engine.server_version
|
||||
|
||||
#
|
||||
# Logging configuration
|
||||
|
|
|
@ -18,6 +18,21 @@ class AccountValidityConfig(Config):
|
|||
section = "account_validity"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
"""Parses the old account validity config. The config format looks like this:
|
||||
|
||||
account_validity:
|
||||
enabled: true
|
||||
period: 6w
|
||||
renew_at: 1w
|
||||
renew_email_subject: "Renew your %(app)s account"
|
||||
template_dir: "res/templates"
|
||||
account_renewed_html_path: "account_renewed.html"
|
||||
invalid_token_html_path: "invalid_token.html"
|
||||
|
||||
We expect admins to use modules for this feature (which is why it doesn't appear
|
||||
in the sample config file), but we want to keep support for it around for a bit
|
||||
for backwards compatibility.
|
||||
"""
|
||||
account_validity_config = config.get("account_validity") or {}
|
||||
self.account_validity_enabled = account_validity_config.get("enabled", False)
|
||||
self.account_validity_renew_by_email_enabled = (
|
||||
|
@ -75,90 +90,3 @@ class AccountValidityConfig(Config):
|
|||
],
|
||||
account_validity_template_dir,
|
||||
)
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
## Account Validity ##
|
||||
|
||||
# Optional account validity configuration. This allows for accounts to be denied
|
||||
# any request after a given period.
|
||||
#
|
||||
# Once this feature is enabled, Synapse will look for registered users without an
|
||||
# expiration date at startup and will add one to every account it found using the
|
||||
# current settings at that time.
|
||||
# This means that, if a validity period is set, and Synapse is restarted (it will
|
||||
# then derive an expiration date from the current validity period), and some time
|
||||
# after that the validity period changes and Synapse is restarted, the users'
|
||||
# expiration dates won't be updated unless their account is manually renewed. This
|
||||
# date will be randomly selected within a range [now + period - d ; now + period],
|
||||
# where d is equal to 10% of the validity period.
|
||||
#
|
||||
account_validity:
|
||||
# The account validity feature is disabled by default. Uncomment the
|
||||
# following line to enable it.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# The period after which an account is valid after its registration. When
|
||||
# renewing the account, its validity period will be extended by this amount
|
||||
# of time. This parameter is required when using the account validity
|
||||
# feature.
|
||||
#
|
||||
#period: 6w
|
||||
|
||||
# The amount of time before an account's expiry date at which Synapse will
|
||||
# send an email to the account's email address with a renewal link. By
|
||||
# default, no such emails are sent.
|
||||
#
|
||||
# If you enable this setting, you will also need to fill out the 'email' and
|
||||
# 'public_baseurl' configuration sections.
|
||||
#
|
||||
#renew_at: 1w
|
||||
|
||||
# The subject of the email sent out with the renewal link. '%(app)s' can be
|
||||
# used as a placeholder for the 'app_name' parameter from the 'email'
|
||||
# section.
|
||||
#
|
||||
# Note that the placeholder must be written '%(app)s', including the
|
||||
# trailing 's'.
|
||||
#
|
||||
# If this is not set, a default value is used.
|
||||
#
|
||||
#renew_email_subject: "Renew your %(app)s account"
|
||||
|
||||
# Directory in which Synapse will try to find templates for the HTML files to
|
||||
# serve to the user when trying to renew an account. If not set, default
|
||||
# templates from within the Synapse package will be used.
|
||||
#
|
||||
# The currently available templates are:
|
||||
#
|
||||
# * account_renewed.html: Displayed to the user after they have successfully
|
||||
# renewed their account.
|
||||
#
|
||||
# * account_previously_renewed.html: Displayed to the user if they attempt to
|
||||
# renew their account with a token that is valid, but that has already
|
||||
# been used. In this case the account is not renewed again.
|
||||
#
|
||||
# * invalid_token.html: Displayed to the user when they try to renew an account
|
||||
# with an unknown or invalid renewal token.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
|
||||
# default template contents.
|
||||
#
|
||||
# The file name of some of these templates can be configured below for legacy
|
||||
# reasons.
|
||||
#
|
||||
#template_dir: "res/templates"
|
||||
|
||||
# A custom file name for the 'account_renewed.html' template.
|
||||
#
|
||||
# If not set, the file is assumed to be named "account_renewed.html".
|
||||
#
|
||||
#account_renewed_html_path: "account_renewed.html"
|
||||
|
||||
# A custom file name for the 'invalid_token.html' template.
|
||||
#
|
||||
# If not set, the file is assumed to be named "invalid_token.html".
|
||||
#
|
||||
#invalid_token_html_path: "invalid_token.html"
|
||||
"""
|
||||
|
|
|
@ -64,7 +64,7 @@ def load_appservices(hostname, config_files):
|
|||
|
||||
for config_file in config_files:
|
||||
try:
|
||||
with open(config_file, "r") as f:
|
||||
with open(config_file) as f:
|
||||
appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
|
||||
if appservice.id in seen_ids:
|
||||
raise ConfigError(
|
||||
|
|
|
@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config):
|
|||
self.third_party_event_rules = load_module(
|
||||
provider, ("third_party_event_rules",)
|
||||
)
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
# Server admins can define a Python module that implements extra rules for
|
||||
# allowing or denying incoming events. In order to work, this module needs to
|
||||
# override the methods defined in synapse/events/third_party_rules.py.
|
||||
#
|
||||
# This feature is designed to be used in closed federations only, where each
|
||||
# participating server enforces the same rules.
|
||||
#
|
||||
#third_party_event_rules:
|
||||
# module: "my_custom_project.SuperRulesSet"
|
||||
# config:
|
||||
# example_option: 'things'
|
||||
"""
|
||||
|
|
|
@ -66,10 +66,8 @@ class TlsConfig(Config):
|
|||
if self.federation_client_minimum_tls_version == "1.3":
|
||||
if getattr(SSL, "OP_NO_TLSv1_3", None) is None:
|
||||
raise ConfigError(
|
||||
(
|
||||
"federation_client_minimum_tls_version cannot be 1.3, "
|
||||
"your OpenSSL does not support it"
|
||||
)
|
||||
"federation_client_minimum_tls_version cannot be 1.3, "
|
||||
"your OpenSSL does not support it"
|
||||
)
|
||||
|
||||
# Whitelist of domains to not verify certificates for
|
||||
|
|
|
@ -291,6 +291,20 @@ class EventBase(metaclass=abc.ABCMeta):
|
|||
|
||||
return pdu_json
|
||||
|
||||
def get_templated_pdu_json(self) -> JsonDict:
|
||||
"""
|
||||
Return a JSON object suitable for a templated event, as used in the
|
||||
make_{join,leave,knock} workflow.
|
||||
"""
|
||||
# By using _dict directly we don't pull in signatures/unsigned.
|
||||
template_json = dict(self._dict)
|
||||
# The hashes (similar to the signature) need to be recalculated by the
|
||||
# joining/leaving/knocking server after (potentially) modifying the
|
||||
# event.
|
||||
template_json.pop("hashes")
|
||||
|
||||
return template_json
|
||||
|
||||
def __set__(self, instance, value):
|
||||
raise AttributeError("Unrecognized attribute %s" % (instance,))
|
||||
|
||||
|
|
|
@ -11,16 +11,124 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import Requester, StateMap
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
|
||||
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
|
||||
]
|
||||
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
|
||||
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
|
||||
[str, str, StateMap[EventBase]], Awaitable[bool]
|
||||
]
|
||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
|
||||
[str, StateMap[EventBase], str], Awaitable[bool]
|
||||
]
|
||||
|
||||
|
||||
def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
||||
"""Wrapper that loads a third party event rules module configured using the old
|
||||
configuration, and registers the hooks they implement.
|
||||
"""
|
||||
if hs.config.third_party_event_rules is None:
|
||||
return
|
||||
|
||||
module, config = hs.config.third_party_event_rules
|
||||
|
||||
api = hs.get_module_api()
|
||||
third_party_rules = module(config=config, module_api=api)
|
||||
|
||||
# The known hooks. If a module implements a method which name appears in this set,
|
||||
# we'll want to register it.
|
||||
third_party_event_rules_methods = {
|
||||
"check_event_allowed",
|
||||
"on_create_room",
|
||||
"check_threepid_can_be_invited",
|
||||
"check_visibility_can_be_modified",
|
||||
}
|
||||
|
||||
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
|
||||
# f might be None if the callback isn't implemented by the module. In this
|
||||
# case we don't want to register a callback at all so we return None.
|
||||
if f is None:
|
||||
return None
|
||||
|
||||
# We return a separate wrapper for these methods because, in order to wrap them
|
||||
# correctly, we need to await its result. Therefore it doesn't make a lot of
|
||||
# sense to make it go through the run() wrapper.
|
||||
if f.__name__ == "check_event_allowed":
|
||||
|
||||
# We need to wrap check_event_allowed because its old form would return either
|
||||
# a boolean or a dict, but now we want to return the dict separately from the
|
||||
# boolean.
|
||||
async def wrap_check_event_allowed(
|
||||
event: EventBase,
|
||||
state_events: StateMap[EventBase],
|
||||
) -> Tuple[bool, Optional[dict]]:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
assert f is not None
|
||||
|
||||
res = await f(event, state_events)
|
||||
if isinstance(res, dict):
|
||||
return True, res
|
||||
else:
|
||||
return res, None
|
||||
|
||||
return wrap_check_event_allowed
|
||||
|
||||
if f.__name__ == "on_create_room":
|
||||
|
||||
# We need to wrap on_create_room because its old form would return a boolean
|
||||
# if the room creation is denied, but now we just want it to raise an
|
||||
# exception.
|
||||
async def wrap_on_create_room(
|
||||
requester: Requester, config: dict, is_requester_admin: bool
|
||||
) -> None:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
assert f is not None
|
||||
|
||||
res = await f(requester, config, is_requester_admin)
|
||||
if res is False:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Room creation forbidden with these parameters",
|
||||
)
|
||||
|
||||
return wrap_on_create_room
|
||||
|
||||
def run(*args, **kwargs):
|
||||
# mypy doesn't do well across function boundaries so we need to tell it
|
||||
# f is definitely not None.
|
||||
assert f is not None
|
||||
|
||||
return maybe_awaitable(f(*args, **kwargs))
|
||||
|
||||
return run
|
||||
|
||||
# Register the hooks through the module API.
|
||||
hooks = {
|
||||
hook: async_wrapper(getattr(third_party_rules, hook, None))
|
||||
for hook in third_party_event_rules_methods
|
||||
}
|
||||
|
||||
api.register_third_party_rules_callbacks(**hooks)
|
||||
|
||||
|
||||
class ThirdPartyEventRules:
|
||||
"""Allows server admins to provide a Python module implementing an extra
|
||||
|
@ -35,36 +143,65 @@ class ThirdPartyEventRules:
|
|||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
module = None
|
||||
config = None
|
||||
if hs.config.third_party_event_rules:
|
||||
module, config = hs.config.third_party_event_rules
|
||||
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
|
||||
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
|
||||
self._check_threepid_can_be_invited_callbacks: List[
|
||||
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
|
||||
] = []
|
||||
self._check_visibility_can_be_modified_callbacks: List[
|
||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
||||
] = []
|
||||
|
||||
if module is not None:
|
||||
self.third_party_rules = module(
|
||||
config=config,
|
||||
module_api=hs.get_module_api(),
|
||||
def register_third_party_rules_callbacks(
|
||||
self,
|
||||
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
|
||||
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
|
||||
check_threepid_can_be_invited: Optional[
|
||||
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
|
||||
] = None,
|
||||
check_visibility_can_be_modified: Optional[
|
||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
||||
] = None,
|
||||
):
|
||||
"""Register callbacks from modules for each hook."""
|
||||
if check_event_allowed is not None:
|
||||
self._check_event_allowed_callbacks.append(check_event_allowed)
|
||||
|
||||
if on_create_room is not None:
|
||||
self._on_create_room_callbacks.append(on_create_room)
|
||||
|
||||
if check_threepid_can_be_invited is not None:
|
||||
self._check_threepid_can_be_invited_callbacks.append(
|
||||
check_threepid_can_be_invited,
|
||||
)
|
||||
|
||||
if check_visibility_can_be_modified is not None:
|
||||
self._check_visibility_can_be_modified_callbacks.append(
|
||||
check_visibility_can_be_modified,
|
||||
)
|
||||
|
||||
async def check_event_allowed(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Union[bool, dict]:
|
||||
) -> Tuple[bool, Optional[dict]]:
|
||||
"""Check if a provided event should be allowed in the given context.
|
||||
|
||||
The module can return:
|
||||
* True: the event is allowed.
|
||||
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
|
||||
* a dict: replacement event data.
|
||||
|
||||
If the event is allowed, the module can also return a dictionary to use as a
|
||||
replacement for the event.
|
||||
|
||||
Args:
|
||||
event: The event to be checked.
|
||||
context: The context of the event.
|
||||
|
||||
Returns:
|
||||
The result from the ThirdPartyRules module, as above
|
||||
The result from the ThirdPartyRules module, as above.
|
||||
"""
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
# Bail out early without hitting the store if we don't have any callbacks to run.
|
||||
if len(self._check_event_allowed_callbacks) == 0:
|
||||
return True, None
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
|
@ -77,29 +214,46 @@ class ThirdPartyEventRules:
|
|||
# the hashes and signatures.
|
||||
event.freeze()
|
||||
|
||||
return await self.third_party_rules.check_event_allowed(event, state_events)
|
||||
for callback in self._check_event_allowed_callbacks:
|
||||
try:
|
||||
res, replacement_data = await callback(event, state_events)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||
continue
|
||||
|
||||
# Return if the event shouldn't be allowed or if the module came up with a
|
||||
# replacement dict for the event.
|
||||
if res is False:
|
||||
return res, None
|
||||
elif isinstance(replacement_data, dict):
|
||||
return True, replacement_data
|
||||
|
||||
return True, None
|
||||
|
||||
async def on_create_room(
|
||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||
) -> bool:
|
||||
"""Intercept requests to create room to allow, deny or update the
|
||||
request config.
|
||||
) -> None:
|
||||
"""Intercept requests to create room to maybe deny it (via an exception) or
|
||||
update the request config.
|
||||
|
||||
Args:
|
||||
requester
|
||||
config: The creation config from the client.
|
||||
is_requester_admin: If the requester is an admin
|
||||
|
||||
Returns:
|
||||
Whether room creation is allowed or denied.
|
||||
"""
|
||||
for callback in self._on_create_room_callbacks:
|
||||
try:
|
||||
await callback(requester, config, is_requester_admin)
|
||||
except Exception as e:
|
||||
# Don't silence the errors raised by this callback since we expect it to
|
||||
# raise an exception to deny the creation of the room; instead make sure
|
||||
# it's a SynapseError we can send to clients.
|
||||
if not isinstance(e, SynapseError):
|
||||
e = SynapseError(
|
||||
403, "Room creation forbidden with these parameters"
|
||||
)
|
||||
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
return await self.third_party_rules.on_create_room(
|
||||
requester, config, is_requester_admin
|
||||
)
|
||||
raise e
|
||||
|
||||
async def check_threepid_can_be_invited(
|
||||
self, medium: str, address: str, room_id: str
|
||||
|
@ -114,15 +268,20 @@ class ThirdPartyEventRules:
|
|||
Returns:
|
||||
True if the 3PID can be invited, False if not.
|
||||
"""
|
||||
|
||||
if self.third_party_rules is None:
|
||||
# Bail out early without hitting the store if we don't have any callbacks to run.
|
||||
if len(self._check_threepid_can_be_invited_callbacks) == 0:
|
||||
return True
|
||||
|
||||
state_events = await self._get_state_map_for_room(room_id)
|
||||
|
||||
return await self.third_party_rules.check_threepid_can_be_invited(
|
||||
medium, address, state_events
|
||||
)
|
||||
for callback in self._check_threepid_can_be_invited_callbacks:
|
||||
try:
|
||||
if await callback(medium, address, state_events) is False:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||
|
||||
return True
|
||||
|
||||
async def check_visibility_can_be_modified(
|
||||
self, room_id: str, new_visibility: str
|
||||
|
@ -137,18 +296,20 @@ class ThirdPartyEventRules:
|
|||
Returns:
|
||||
True if the room's visibility can be modified, False if not.
|
||||
"""
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
check_func = getattr(
|
||||
self.third_party_rules, "check_visibility_can_be_modified", None
|
||||
)
|
||||
if not check_func or not callable(check_func):
|
||||
# Bail out early without hitting the store if we don't have any callback
|
||||
if len(self._check_visibility_can_be_modified_callbacks) == 0:
|
||||
return True
|
||||
|
||||
state_events = await self._get_state_map_for_room(room_id)
|
||||
|
||||
return await check_func(room_id, state_events, new_visibility)
|
||||
for callback in self._check_visibility_can_be_modified_callbacks:
|
||||
try:
|
||||
if await callback(room_id, state_events, new_visibility) is False:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Failed to run module API callback %s: %s", callback, e)
|
||||
|
||||
return True
|
||||
|
||||
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
|
||||
"""Given a room ID, return the state events of that room.
|
||||
|
|
|
@ -562,8 +562,7 @@ class FederationServer(FederationBase):
|
|||
raise IncompatibleRoomVersionError(room_version=room_version)
|
||||
|
||||
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
|
||||
time_now = self._clock.time_msec()
|
||||
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
||||
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
|
||||
|
||||
async def on_invite_request(
|
||||
self, origin: str, content: JsonDict, room_version_id: str
|
||||
|
@ -611,8 +610,7 @@ class FederationServer(FederationBase):
|
|||
|
||||
room_version = await self.store.get_room_version_id(room_id)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
||||
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
|
||||
|
||||
async def on_send_leave_request(
|
||||
self, origin: str, content: JsonDict, room_id: str
|
||||
|
@ -659,9 +657,8 @@ class FederationServer(FederationBase):
|
|||
)
|
||||
|
||||
pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
|
||||
time_now = self._clock.time_msec()
|
||||
return {
|
||||
"event": pdu.get_pdu_json(time_now),
|
||||
"event": pdu.get_templated_pdu_json(),
|
||||
"room_version": room_version.identifier,
|
||||
}
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ class BaseHandler:
|
|||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
@ -55,12 +55,12 @@ class BaseHandler:
|
|||
# Check whether ratelimiting room admin message redaction is enabled
|
||||
# by the presence of rate limits in the config
|
||||
if self.hs.config.rc_admin_redaction:
|
||||
self.admin_redaction_ratelimiter = Ratelimiter(
|
||||
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||
) # type: Optional[Ratelimiter]
|
||||
)
|
||||
else:
|
||||
self.admin_redaction_ratelimiter = None
|
||||
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
import email.mime.multipart
|
||||
import email.utils
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
|
@ -27,6 +29,15 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Types for callbacks to be registered via the module api
|
||||
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
|
||||
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
|
||||
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
|
||||
# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
|
||||
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
|
||||
ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
|
||||
ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
|
||||
|
||||
|
||||
class AccountValidityHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -70,6 +81,99 @@ class AccountValidityHandler:
|
|||
if hs.config.run_background_tasks:
|
||||
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
||||
|
||||
self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
|
||||
self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
|
||||
self._on_legacy_send_mail_callback: Optional[
|
||||
ON_LEGACY_SEND_MAIL_CALLBACK
|
||||
] = None
|
||||
self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
|
||||
|
||||
# The legacy admin requests callback isn't a protected attribute because we need
|
||||
# to access it from the admin servlet, which is outside of this handler.
|
||||
self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
|
||||
|
||||
def register_account_validity_callbacks(
|
||||
self,
|
||||
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
|
||||
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
|
||||
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
|
||||
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
|
||||
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
|
||||
):
|
||||
"""Register callbacks from module for each hook."""
|
||||
if is_user_expired is not None:
|
||||
self._is_user_expired_callbacks.append(is_user_expired)
|
||||
|
||||
if on_user_registration is not None:
|
||||
self._on_user_registration_callbacks.append(on_user_registration)
|
||||
|
||||
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
|
||||
# an admin one). As part of moving the feature into a module, we need to change
|
||||
# the path from /_matrix/client/unstable/account_validity/... to
|
||||
# /_synapse/client/account_validity, because:
|
||||
#
|
||||
# * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
|
||||
# * the way we register servlets means that modules can't register resources
|
||||
# under /_matrix/client
|
||||
#
|
||||
# We need to allow for a transition period between the old and new endpoints
|
||||
# in order to allow for clients to update (and for emails to be processed).
|
||||
#
|
||||
# Once the email-account-validity module is loaded, it will take control of account
|
||||
# validity by moving the rows from our `account_validity` table into its own table.
|
||||
#
|
||||
# Therefore, we need to allow modules (in practice just the one implementing the
|
||||
# email-based account validity) to temporarily hook into the legacy endpoints so we
|
||||
# can route the traffic coming into the old endpoints into the module, which is
|
||||
# why we have the following three temporary hooks.
|
||||
if on_legacy_send_mail is not None:
|
||||
if self._on_legacy_send_mail_callback is not None:
|
||||
raise RuntimeError("Tried to register on_legacy_send_mail twice")
|
||||
|
||||
self._on_legacy_send_mail_callback = on_legacy_send_mail
|
||||
|
||||
if on_legacy_renew is not None:
|
||||
if self._on_legacy_renew_callback is not None:
|
||||
raise RuntimeError("Tried to register on_legacy_renew twice")
|
||||
|
||||
self._on_legacy_renew_callback = on_legacy_renew
|
||||
|
||||
if on_legacy_admin_request is not None:
|
||||
if self.on_legacy_admin_request_callback is not None:
|
||||
raise RuntimeError("Tried to register on_legacy_admin_request twice")
|
||||
|
||||
self.on_legacy_admin_request_callback = on_legacy_admin_request
|
||||
|
||||
async def is_user_expired(self, user_id: str) -> bool:
|
||||
"""Checks if a user has expired against third-party modules.
|
||||
|
||||
Args:
|
||||
user_id: The user to check the expiry of.
|
||||
|
||||
Returns:
|
||||
Whether the user has expired.
|
||||
"""
|
||||
for callback in self._is_user_expired_callbacks:
|
||||
expired = await callback(user_id)
|
||||
if expired is not None:
|
||||
return expired
|
||||
|
||||
if self._account_validity_enabled:
|
||||
# If no module could determine whether the user has expired and the legacy
|
||||
# configuration is enabled, fall back to it.
|
||||
return await self.store.is_account_expired(user_id, self.clock.time_msec())
|
||||
|
||||
return False
|
||||
|
||||
async def on_user_registration(self, user_id: str):
|
||||
"""Tell third-party modules about a user's registration.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the newly registered user.
|
||||
"""
|
||||
for callback in self._on_user_registration_callbacks:
|
||||
await callback(user_id)
|
||||
|
||||
@wrap_as_background_process("send_renewals")
|
||||
async def _send_renewal_emails(self) -> None:
|
||||
"""Gets the list of users whose account is expiring in the amount of time
|
||||
|
@ -95,6 +199,17 @@ class AccountValidityHandler:
|
|||
Raises:
|
||||
SynapseError if the user is not set to renew.
|
||||
"""
|
||||
# If a module supports sending a renewal email from here, do that, otherwise do
|
||||
# the legacy dance.
|
||||
if self._on_legacy_send_mail_callback is not None:
|
||||
await self._on_legacy_send_mail_callback(user_id)
|
||||
return
|
||||
|
||||
if not self._account_validity_renew_by_email_enabled:
|
||||
raise AuthError(
|
||||
403, "Account renewal via email is disabled on this server."
|
||||
)
|
||||
|
||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||
|
||||
# If this user isn't set to be expired, raise an error.
|
||||
|
@ -209,6 +324,10 @@ class AccountValidityHandler:
|
|||
token is considered stale. A token is stale if the 'token_used_ts_ms' db column
|
||||
is non-null.
|
||||
|
||||
This method exists to support handling the legacy account validity /renew
|
||||
endpoint. If a module implements the on_legacy_renew callback, then this process
|
||||
is delegated to the module instead.
|
||||
|
||||
Args:
|
||||
renewal_token: Token sent with the renewal request.
|
||||
Returns:
|
||||
|
@ -218,6 +337,11 @@ class AccountValidityHandler:
|
|||
* An int representing the user's expiry timestamp as milliseconds since the
|
||||
epoch, or 0 if the token was invalid.
|
||||
"""
|
||||
# If a module supports triggering a renew from here, do that, otherwise do the
|
||||
# legacy dance.
|
||||
if self._on_legacy_renew_callback is not None:
|
||||
return await self._on_legacy_renew_callback(renewal_token)
|
||||
|
||||
try:
|
||||
(
|
||||
user_id,
|
||||
|
|
|
@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
|
|||
to_key = RoomStreamToken(None, stream_ordering)
|
||||
|
||||
# Events that we've processed in this room
|
||||
written_events = set() # type: Set[str]
|
||||
written_events: Set[str] = set()
|
||||
|
||||
# We need to track gaps in the events stream so that we can then
|
||||
# write out the state at those events. We do this by keeping track
|
||||
|
@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
|
|||
# The reverse mapping to above, i.e. map from unseen event to events
|
||||
# that have the unseen event in their prev_events, i.e. the unseen
|
||||
# events "children".
|
||||
unseen_to_child_events = {} # type: Dict[str, Set[str]]
|
||||
unseen_to_child_events: Dict[str, Set[str]] = {}
|
||||
|
||||
# We fetch events in the room the user could see by fetching *all*
|
||||
# events that we have and then filtering, this isn't the most
|
||||
|
|
|
@ -96,7 +96,7 @@ class ApplicationServicesHandler:
|
|||
self.current_max, limit
|
||||
)
|
||||
|
||||
events_by_room = {} # type: Dict[str, List[EventBase]]
|
||||
events_by_room: Dict[str, List[EventBase]] = {}
|
||||
for event in events:
|
||||
events_by_room.setdefault(event.room_id, []).append(event)
|
||||
|
||||
|
@ -275,7 +275,7 @@ class ApplicationServicesHandler:
|
|||
async def _handle_presence(
|
||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||
) -> List[JsonDict]:
|
||||
events = [] # type: List[JsonDict]
|
||||
events: List[JsonDict] = []
|
||||
presence_source = self.event_sources.sources["presence"]
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
service, "presence"
|
||||
|
@ -375,7 +375,7 @@ class ApplicationServicesHandler:
|
|||
self, only_protocol: Optional[str] = None
|
||||
) -> Dict[str, JsonDict]:
|
||||
services = self.store.get_app_services()
|
||||
protocols = {} # type: Dict[str, List[JsonDict]]
|
||||
protocols: Dict[str, List[JsonDict]] = {}
|
||||
|
||||
# Collect up all the individual protocol responses out of the ASes
|
||||
for s in services:
|
||||
|
|
|
@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
||||
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
|
||||
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
||||
inst = auth_checker_class(hs)
|
||||
if inst.is_enabled():
|
||||
|
@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# A mapping of user ID to extra attributes to include in the login
|
||||
# response.
|
||||
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
|
||||
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
|
||||
|
||||
async def validate_user_via_ui_auth(
|
||||
self,
|
||||
|
@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
|
|||
all the stages in any of the permitted flows.
|
||||
"""
|
||||
|
||||
sid = None # type: Optional[str]
|
||||
sid: Optional[str] = None
|
||||
authdict = clientdict.pop("auth", {})
|
||||
if "session" in authdict:
|
||||
sid = authdict["session"]
|
||||
|
@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
|
||||
# check auth type currently being presented
|
||||
errordict = {} # type: Dict[str, Any]
|
||||
errordict: Dict[str, Any] = {}
|
||||
if "type" in authdict:
|
||||
login_type = authdict["type"] # type: str
|
||||
login_type: str = authdict["type"]
|
||||
try:
|
||||
result = await self._check_auth_dict(authdict, clientip)
|
||||
if result:
|
||||
|
@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
|
|||
LoginType.TERMS: self._get_params_terms,
|
||||
}
|
||||
|
||||
params = {} # type: Dict[str, Any]
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
for f in public_flows:
|
||||
for stage in f:
|
||||
|
@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
|
|||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
user_id_to_verify = await self.get_session_data(
|
||||
user_id_to_verify: str = await self.get_session_data(
|
||||
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
||||
) # type: str
|
||||
)
|
||||
|
||||
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
|
||||
user_id_to_verify
|
||||
|
|
|
@ -40,7 +40,7 @@ class CasError(Exception):
|
|||
|
||||
def __str__(self):
|
||||
if self.error_description:
|
||||
return "{}: {}".format(self.error, self.error_description)
|
||||
return f"{self.error}: {self.error_description}"
|
||||
return self.error
|
||||
|
||||
|
||||
|
@ -171,7 +171,7 @@ class CasHandler:
|
|||
|
||||
# Iterate through the nodes and pull out the user and any extra attributes.
|
||||
user = None
|
||||
attributes = {} # type: Dict[str, List[Optional[str]]]
|
||||
attributes: Dict[str, List[Optional[str]]] = {}
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
|
|
|
@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
user_id
|
||||
)
|
||||
|
||||
hosts = set() # type: Set[str]
|
||||
hosts: Set[str] = set()
|
||||
if self.hs.is_mine_id(user_id):
|
||||
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
|
||||
hosts.discard(self.server_name)
|
||||
|
@ -613,20 +613,20 @@ class DeviceListUpdater:
|
|||
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
|
||||
|
||||
# user_id -> list of updates waiting to be handled.
|
||||
self._pending_updates = (
|
||||
{}
|
||||
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
|
||||
self._pending_updates: Dict[
|
||||
str, List[Tuple[str, str, Iterable[str], JsonDict]]
|
||||
] = {}
|
||||
|
||||
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||
# but they're useful to have them about to reduce the number of spurious
|
||||
# resyncs.
|
||||
self._seen_updates = ExpiringCache(
|
||||
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
|
||||
cache_name="device_update_edu",
|
||||
clock=self.clock,
|
||||
max_len=10000,
|
||||
expiry_ms=30 * 60 * 1000,
|
||||
iterable=True,
|
||||
) # type: ExpiringCache[str, Set[str]]
|
||||
)
|
||||
|
||||
# Attempt to resync out of sync device lists every 30s.
|
||||
self._resync_retry_in_progress = False
|
||||
|
@ -755,7 +755,7 @@ class DeviceListUpdater:
|
|||
"""Given a list of updates for a user figure out if we need to do a full
|
||||
resync, or whether we have enough data that we can just apply the delta.
|
||||
"""
|
||||
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
|
||||
seen_updates: Set[str] = self._seen_updates.get(user_id, set())
|
||||
|
||||
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
|
||||
|
||||
|
|
|
@ -203,7 +203,7 @@ class DeviceMessageHandler:
|
|||
log_kv({"number_of_to_device_messages": len(messages)})
|
||||
set_tag("sender", sender_user_id)
|
||||
local_messages = {}
|
||||
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
||||
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for user_id, by_device in messages.items():
|
||||
# Ratelimit local cross-user key requests by the sending device.
|
||||
if (
|
||||
|
|
|
@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler):
|
|||
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
|
||||
room_id = None
|
||||
if self.hs.is_mine(room_alias):
|
||||
result = await self.get_association_from_room_alias(
|
||||
room_alias
|
||||
) # type: Optional[RoomAliasMapping]
|
||||
result: Optional[
|
||||
RoomAliasMapping
|
||||
] = await self.get_association_from_room_alias(room_alias)
|
||||
|
||||
if result:
|
||||
room_id = result.room_id
|
||||
|
|
|
@ -115,9 +115,9 @@ class E2eKeysHandler:
|
|||
the number of in-flight queries at a time.
|
||||
"""
|
||||
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
||||
device_keys_query = query_body.get(
|
||||
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
|
||||
"device_keys", {}
|
||||
) # type: Dict[str, Iterable[str]]
|
||||
)
|
||||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
|
@ -136,7 +136,7 @@ class E2eKeysHandler:
|
|||
|
||||
# First get local devices.
|
||||
# A map of destination -> failure response.
|
||||
failures = {} # type: Dict[str, JsonDict]
|
||||
failures: Dict[str, JsonDict] = {}
|
||||
results = {}
|
||||
if local_query:
|
||||
local_result = await self.query_local_devices(local_query)
|
||||
|
@ -151,11 +151,9 @@ class E2eKeysHandler:
|
|||
|
||||
# Now attempt to get any remote devices from our local cache.
|
||||
# A map of destination -> user ID -> device IDs.
|
||||
remote_queries_not_in_cache = (
|
||||
{}
|
||||
) # type: Dict[str, Dict[str, Iterable[str]]]
|
||||
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
|
||||
if remote_queries:
|
||||
query_list = [] # type: List[Tuple[str, Optional[str]]]
|
||||
query_list: List[Tuple[str, Optional[str]]] = []
|
||||
for user_id, device_ids in remote_queries.items():
|
||||
if device_ids:
|
||||
query_list.extend(
|
||||
|
@ -362,9 +360,9 @@ class E2eKeysHandler:
|
|||
A map from user_id -> device_id -> device details
|
||||
"""
|
||||
set_tag("local_query", query)
|
||||
local_query = [] # type: List[Tuple[str, Optional[str]]]
|
||||
local_query: List[Tuple[str, Optional[str]]] = []
|
||||
|
||||
result_dict = {} # type: Dict[str, Dict[str, dict]]
|
||||
result_dict: Dict[str, Dict[str, dict]] = {}
|
||||
for user_id, device_ids in query.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if not self.is_mine(UserID.from_string(user_id)):
|
||||
|
@ -402,9 +400,9 @@ class E2eKeysHandler:
|
|||
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
|
||||
) -> JsonDict:
|
||||
"""Handle a device key query from a federated server"""
|
||||
device_keys_query = query_body.get(
|
||||
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
|
||||
"device_keys", {}
|
||||
) # type: Dict[str, Optional[List[str]]]
|
||||
)
|
||||
res = await self.query_local_devices(device_keys_query)
|
||||
ret = {"device_keys": res}
|
||||
|
||||
|
@ -421,8 +419,8 @@ class E2eKeysHandler:
|
|||
async def claim_one_time_keys(
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
|
||||
) -> JsonDict:
|
||||
local_query = [] # type: List[Tuple[str, str, str]]
|
||||
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
|
||||
local_query: List[Tuple[str, str, str]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
||||
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
|
@ -439,8 +437,8 @@ class E2eKeysHandler:
|
|||
results = await self.store.claim_e2e_one_time_keys(local_query)
|
||||
|
||||
# A map of user ID -> device ID -> key ID -> key.
|
||||
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
||||
failures = {} # type: Dict[str, JsonDict]
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
failures: Dict[str, JsonDict] = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, json_str in keys.items():
|
||||
|
@ -768,8 +766,8 @@ class E2eKeysHandler:
|
|||
Raises:
|
||||
SynapseError: if the input is malformed
|
||||
"""
|
||||
signature_list = [] # type: List[SignatureListItem]
|
||||
failures = {} # type: Dict[str, Dict[str, JsonDict]]
|
||||
signature_list: List["SignatureListItem"] = []
|
||||
failures: Dict[str, Dict[str, JsonDict]] = {}
|
||||
if not signatures:
|
||||
return signature_list, failures
|
||||
|
||||
|
@ -930,8 +928,8 @@ class E2eKeysHandler:
|
|||
Raises:
|
||||
SynapseError: if the input is malformed
|
||||
"""
|
||||
signature_list = [] # type: List[SignatureListItem]
|
||||
failures = {} # type: Dict[str, Dict[str, JsonDict]]
|
||||
signature_list: List["SignatureListItem"] = []
|
||||
failures: Dict[str, Dict[str, JsonDict]] = {}
|
||||
if not signatures:
|
||||
return signature_list, failures
|
||||
|
||||
|
@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
|
|||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||
|
||||
# user_id -> list of updates waiting to be handled.
|
||||
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
|
||||
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
|
||||
|
||||
async def incoming_signing_key_update(
|
||||
self, origin: str, edu_content: JsonDict
|
||||
|
@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
|
|||
# This can happen since we batch updates
|
||||
return
|
||||
|
||||
device_ids = [] # type: List[str]
|
||||
device_ids: List[str] = []
|
||||
|
||||
logger.info("pending updates: %r", pending_updates)
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
|
|||
|
||||
# When the user joins a new room, or another user joins a currently
|
||||
# joined room, we need to send down presence for those users.
|
||||
to_add = [] # type: List[JsonDict]
|
||||
to_add: List[JsonDict] = []
|
||||
for event in events:
|
||||
if not isinstance(event, EventBase):
|
||||
continue
|
||||
|
@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
|
|||
# Send down presence.
|
||||
if event.state_key == auth_user_id:
|
||||
# Send down presence for everyone in the room.
|
||||
users = await self.store.get_users_in_room(
|
||||
users: Iterable[str] = await self.store.get_users_in_room(
|
||||
event.room_id
|
||||
) # type: Iterable[str]
|
||||
)
|
||||
else:
|
||||
users = [event.state_key]
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# When joining a room we need to queue any events for that room up.
|
||||
# For each room, a list of (pdu, origin) tuples.
|
||||
self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
|
||||
self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
|
||||
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
||||
|
||||
self._room_backfill = Linearizer("room_backfill")
|
||||
|
@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
|
|||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps = list(ours.values()) # type: List[StateMap[str]]
|
||||
state_maps: List[StateMap[str]] = list(ours.values())
|
||||
|
||||
# we don't need this any more, let's delete it.
|
||||
del ours
|
||||
|
@ -735,7 +735,7 @@ class FederationHandler(BaseHandler):
|
|||
# we need to make sure we re-load from the database to get the rejected
|
||||
# state correct.
|
||||
fetched_events.update(
|
||||
(await self.store.get_events(missing_desired_events, allow_rejected=True))
|
||||
await self.store.get_events(missing_desired_events, allow_rejected=True)
|
||||
)
|
||||
|
||||
# check for events which were in the wrong room.
|
||||
|
@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
|
|||
# exact key to expect. Otherwise check it matches any key we
|
||||
# have for that device.
|
||||
|
||||
current_keys = [] # type: Container[str]
|
||||
current_keys: Container[str] = []
|
||||
|
||||
if device:
|
||||
keys = device.get("keys", {}).get("keys", {})
|
||||
|
@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
|
|||
if e_type == EventTypes.Member and event.membership == Membership.JOIN
|
||||
]
|
||||
|
||||
joined_domains = {} # type: Dict[str, int]
|
||||
joined_domains: Dict[str, int] = {}
|
||||
for u, d in joined_users:
|
||||
try:
|
||||
dom = get_domain_from_id(u)
|
||||
|
@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
event_map = {} # type: Dict[str, EventBase]
|
||||
event_map: Dict[str, EventBase] = {}
|
||||
|
||||
async def get_event(event_id: str):
|
||||
with nested_logging_context(event_id):
|
||||
|
@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# Ask the remote server to create a valid knock event for us. Once received,
|
||||
# we sign the event
|
||||
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
|
||||
params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
|
||||
origin, event, event_format_version = await self._make_and_verify_event(
|
||||
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
|
||||
)
|
||||
|
@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler):
|
|||
builder=builder
|
||||
)
|
||||
|
||||
event_allowed = await self.third_party_event_rules.check_event_allowed(
|
||||
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
|
||||
event, context
|
||||
)
|
||||
if not event_allowed:
|
||||
|
@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler):
|
|||
# for knock events, we run the third-party event rules. It's not entirely clear
|
||||
# why we don't do this for other sorts of membership events.
|
||||
if event.membership == Membership.KNOCK:
|
||||
event_allowed = await self.third_party_event_rules.check_event_allowed(
|
||||
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
|
||||
event, context
|
||||
)
|
||||
if not event_allowed:
|
||||
|
@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler):
|
|||
state_sets_d = await self.state_store.get_state_groups(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
|
||||
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
|
||||
state_sets.append(state)
|
||||
current_states = await self.state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
)
|
||||
current_state_ids = {
|
||||
current_state_ids: StateMap[str] = {
|
||||
k: e.event_id for k, e in current_states.items()
|
||||
} # type: StateMap[str]
|
||||
}
|
||||
else:
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
|
@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
# exclude the state key of the new event from the current_state in the context.
|
||||
if event.is_state():
|
||||
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
|
||||
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
|
||||
else:
|
||||
event_key = None
|
||||
state_updates = {
|
||||
|
@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
logger.debug("Checking auth on event %r", event.content)
|
||||
|
||||
last_exception = None # type: Optional[Exception]
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
# for each public key in the 3pid invite event
|
||||
for public_key_object in event_auth.get_public_keys(invite_event):
|
||||
|
|
|
@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
|
|||
async def bulk_get_publicised_groups(
|
||||
self, user_ids: Iterable[str], proxy: bool = True
|
||||
) -> JsonDict:
|
||||
destinations = {} # type: Dict[str, Set[str]]
|
||||
destinations: Dict[str, Set[str]] = {}
|
||||
local_users = set()
|
||||
|
||||
for user_id in user_ids:
|
||||
|
@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
|
|||
raise SynapseError(400, "Some user_ids are not local")
|
||||
|
||||
results = {}
|
||||
failed_results = [] # type: List[str]
|
||||
failed_results: List[str] = []
|
||||
for destination, dest_user_ids in destinations.items():
|
||||
try:
|
||||
r = await self.transport_client.bulk_get_publicised_groups(
|
||||
|
|
|
@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler):
|
|||
)
|
||||
|
||||
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
|
||||
url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
|
||||
url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
|
||||
|
||||
content = {
|
||||
"mxid": mxid,
|
||||
|
@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler):
|
|||
return data["mxid"]
|
||||
except RequestTimedOutError:
|
||||
raise SynapseError(500, "Timed out contacting identity server")
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
logger.warning("Error from v1 identity server lookup: %s" % (e,))
|
||||
|
||||
return None
|
||||
|
|
|
@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
|
|||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.snapshot_cache = ResponseCache(
|
||||
hs.get_clock(), "initial_sync_cache"
|
||||
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
|
||||
self.snapshot_cache: ResponseCache[
|
||||
Tuple[
|
||||
str,
|
||||
Optional[StreamToken],
|
||||
Optional[StreamToken],
|
||||
str,
|
||||
Optional[int],
|
||||
bool,
|
||||
bool,
|
||||
]
|
||||
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
|
|
|
@ -81,7 +81,7 @@ class MessageHandler:
|
|||
|
||||
# The scheduled call to self._expire_event. None if no call is currently
|
||||
# scheduled.
|
||||
self._scheduled_expiry = None # type: Optional[IDelayedCall]
|
||||
self._scheduled_expiry: Optional[IDelayedCall] = None
|
||||
|
||||
if not hs.config.worker_app:
|
||||
run_as_background_process(
|
||||
|
@ -196,9 +196,7 @@ class MessageHandler:
|
|||
room_state_events = await self.state_store.get_state_for_events(
|
||||
[event.event_id], state_filter=state_filter
|
||||
)
|
||||
room_state = room_state_events[
|
||||
event.event_id
|
||||
] # type: Mapping[Any, EventBase]
|
||||
room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
|
||||
else:
|
||||
raise AuthError(
|
||||
403,
|
||||
|
@ -421,9 +419,9 @@ class EventCreationHandler:
|
|||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.third_party_event_rules = (
|
||||
self.third_party_event_rules: "ThirdPartyEventRules" = (
|
||||
self.hs.get_third_party_event_rules()
|
||||
) # type: ThirdPartyEventRules
|
||||
)
|
||||
|
||||
self._block_events_without_consent_error = (
|
||||
self.config.block_events_without_consent_error
|
||||
|
@ -440,7 +438,7 @@ class EventCreationHandler:
|
|||
#
|
||||
# map from room id to time-of-last-attempt.
|
||||
#
|
||||
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
|
||||
self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
|
||||
# The number of forward extremeities before a dummy event is sent.
|
||||
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
||||
|
||||
|
@ -465,9 +463,7 @@ class EventCreationHandler:
|
|||
# Stores the state groups we've recently added to the joined hosts
|
||||
# external cache. Note that the timeout must be significantly less than
|
||||
# the TTL on the external cache.
|
||||
self._external_cache_joined_hosts_updates = (
|
||||
None
|
||||
) # type: Optional[ExpiringCache]
|
||||
self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
|
||||
if self._external_cache.is_enabled():
|
||||
self._external_cache_joined_hosts_updates = ExpiringCache(
|
||||
"_external_cache_joined_hosts_updates",
|
||||
|
@ -953,10 +949,10 @@ class EventCreationHandler:
|
|||
if requester:
|
||||
context.app_service = requester.app_service
|
||||
|
||||
third_party_result = await self.third_party_event_rules.check_event_allowed(
|
||||
res, new_content = await self.third_party_event_rules.check_event_allowed(
|
||||
event, context
|
||||
)
|
||||
if not third_party_result:
|
||||
if res is False:
|
||||
logger.info(
|
||||
"Event %s forbidden by third-party rules",
|
||||
event,
|
||||
|
@ -964,11 +960,11 @@ class EventCreationHandler:
|
|||
raise SynapseError(
|
||||
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
||||
)
|
||||
elif isinstance(third_party_result, dict):
|
||||
elif new_content is not None:
|
||||
# the third-party rules want to replace the event. We'll need to build a new
|
||||
# event.
|
||||
event, context = await self._rebuild_event_after_third_party_rules(
|
||||
third_party_result, event
|
||||
new_content, event
|
||||
)
|
||||
|
||||
self.validator.validate_new(event, self.config)
|
||||
|
@ -1299,7 +1295,7 @@ class EventCreationHandler:
|
|||
# Validate a newly added alias or newly added alt_aliases.
|
||||
|
||||
original_alias = None
|
||||
original_alt_aliases = [] # type: List[str]
|
||||
original_alt_aliases: List[str] = []
|
||||
|
||||
original_event_id = event.unsigned.get("replaces_state")
|
||||
if original_event_id:
|
||||
|
|
|
@ -72,26 +72,26 @@ _SESSION_COOKIES = [
|
|||
(b"oidc_session_no_samesite", b"HttpOnly"),
|
||||
]
|
||||
|
||||
|
||||
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
|
||||
#: OpenID.Core sec 3.1.3.3.
|
||||
Token = TypedDict(
|
||||
"Token",
|
||||
{
|
||||
"access_token": str,
|
||||
"token_type": str,
|
||||
"id_token": Optional[str],
|
||||
"refresh_token": Optional[str],
|
||||
"expires_in": int,
|
||||
"scope": Optional[str],
|
||||
},
|
||||
)
|
||||
class Token(TypedDict):
|
||||
access_token: str
|
||||
token_type: str
|
||||
id_token: Optional[str]
|
||||
refresh_token: Optional[str]
|
||||
expires_in: int
|
||||
scope: Optional[str]
|
||||
|
||||
|
||||
#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
|
||||
#: there is no real point of doing this in our case.
|
||||
JWK = Dict[str, str]
|
||||
|
||||
|
||||
#: A JWK Set, as per RFC7517 sec 5.
|
||||
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
|
||||
class JWKS(TypedDict):
|
||||
keys: List[JWK]
|
||||
|
||||
|
||||
class OidcHandler:
|
||||
|
@ -105,9 +105,9 @@ class OidcHandler:
|
|||
assert provider_confs
|
||||
|
||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||
self._providers = {
|
||||
self._providers: Dict[str, "OidcProvider"] = {
|
||||
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
|
||||
} # type: Dict[str, OidcProvider]
|
||||
}
|
||||
|
||||
async def load_metadata(self) -> None:
|
||||
"""Validate the config and load the metadata from the remote endpoint.
|
||||
|
@ -178,7 +178,7 @@ class OidcHandler:
|
|||
# are two.
|
||||
|
||||
for cookie_name, _ in _SESSION_COOKIES:
|
||||
session = request.getCookie(cookie_name) # type: Optional[bytes]
|
||||
session: Optional[bytes] = request.getCookie(cookie_name)
|
||||
if session is not None:
|
||||
break
|
||||
else:
|
||||
|
@ -255,7 +255,7 @@ class OidcError(Exception):
|
|||
|
||||
def __str__(self):
|
||||
if self.error_description:
|
||||
return "{}: {}".format(self.error, self.error_description)
|
||||
return f"{self.error}: {self.error_description}"
|
||||
return self.error
|
||||
|
||||
|
||||
|
@ -277,7 +277,7 @@ class OidcProvider:
|
|||
self._token_generator = token_generator
|
||||
|
||||
self._config = provider
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
self._callback_url: str = hs.config.oidc_callback_url
|
||||
|
||||
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
|
||||
# We'll insert this into the Path= parameter of any session cookies we set.
|
||||
|
@ -290,7 +290,7 @@ class OidcProvider:
|
|||
self._scopes = provider.scopes
|
||||
self._user_profile_method = provider.user_profile_method
|
||||
|
||||
client_secret = None # type: Union[None, str, JwtClientSecret]
|
||||
client_secret: Optional[Union[str, JwtClientSecret]] = None
|
||||
if provider.client_secret:
|
||||
client_secret = provider.client_secret
|
||||
elif provider.client_secret_jwt_key:
|
||||
|
@ -305,7 +305,7 @@ class OidcProvider:
|
|||
provider.client_id,
|
||||
client_secret,
|
||||
provider.client_auth_method,
|
||||
) # type: ClientAuth
|
||||
)
|
||||
self._client_auth_method = provider.client_auth_method
|
||||
|
||||
# cache of metadata for the identity provider (endpoint uris, mostly). This is
|
||||
|
@ -324,7 +324,7 @@ class OidcProvider:
|
|||
self._allow_existing_users = provider.allow_existing_users
|
||||
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._server_name = hs.config.server_name # type: str
|
||||
self._server_name: str = hs.config.server_name
|
||||
|
||||
# identifier for the external_ids table
|
||||
self.idp_id = provider.idp_id
|
||||
|
@ -639,7 +639,7 @@ class OidcProvider:
|
|||
)
|
||||
logger.warning(description)
|
||||
# Body was still valid JSON. Might be useful to log it for debugging.
|
||||
logger.warning("Code exchange response: {resp!r}".format(resp=resp))
|
||||
logger.warning("Code exchange response: %r", resp)
|
||||
raise OidcError("server_error", description)
|
||||
|
||||
return resp
|
||||
|
@ -1217,10 +1217,12 @@ class OidcSessionData:
|
|||
ui_auth_session_id = attr.ib(type=str)
|
||||
|
||||
|
||||
UserAttributeDict = TypedDict(
|
||||
"UserAttributeDict",
|
||||
{"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
|
||||
)
|
||||
class UserAttributeDict(TypedDict):
|
||||
localpart: Optional[str]
|
||||
display_name: Optional[str]
|
||||
emails: List[str]
|
||||
|
||||
|
||||
C = TypeVar("C")
|
||||
|
||||
|
||||
|
@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
if display_name == "":
|
||||
display_name = None
|
||||
|
||||
emails = [] # type: List[str]
|
||||
emails: List[str] = []
|
||||
email = render_template_field(self._config.email_template)
|
||||
if email:
|
||||
emails.append(email)
|
||||
|
@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
)
|
||||
|
||||
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||
extras = {} # type: Dict[str, str]
|
||||
extras: Dict[str, str] = {}
|
||||
for key, template in self._config.extra_attributes.items():
|
||||
try:
|
||||
extras[key] = template.render(user=userinfo).strip()
|
||||
|
|
|
@ -81,9 +81,9 @@ class PaginationHandler:
|
|||
self._server_name = hs.hostname
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
self._purges_in_progress_by_room = set() # type: Set[str]
|
||||
self._purges_in_progress_by_room: Set[str] = set()
|
||||
# map from purge id to PurgeStatus
|
||||
self._purges_by_id = {} # type: Dict[str, PurgeStatus]
|
||||
self._purges_by_id: Dict[str, PurgeStatus] = {}
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
||||
|
|
|
@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
|||
|
||||
# The number of ongoing syncs on this process, by user id.
|
||||
# Empty if _presence_enabled is false.
|
||||
self._user_to_num_current_syncs = {} # type: Dict[str, int]
|
||||
self._user_to_num_current_syncs: Dict[str, int] = {}
|
||||
|
||||
self.notifier = hs.get_notifier()
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
# user_id -> last_sync_ms. Lists the users that have stopped syncing but
|
||||
# we haven't notified the presence writer of that yet
|
||||
self.users_going_offline = {} # type: Dict[str, int]
|
||||
self.users_going_offline: Dict[str, int] = {}
|
||||
|
||||
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
|
||||
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
|
||||
|
@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
|
||||
# Set of users who have presence in the `user_to_current_state` that
|
||||
# have not yet been persisted
|
||||
self.unpersisted_users_changes = set() # type: Set[str]
|
||||
self.unpersisted_users_changes: Set[str] = set()
|
||||
|
||||
hs.get_reactor().addSystemEventTrigger(
|
||||
"before",
|
||||
|
@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
|
||||
# Keeps track of the number of *ongoing* syncs on this process. While
|
||||
# this is non zero a user will never go offline.
|
||||
self.user_to_num_current_syncs = {} # type: Dict[str, int]
|
||||
self.user_to_num_current_syncs: Dict[str, int] = {}
|
||||
|
||||
# Keeps track of the number of *ongoing* syncs on other processes.
|
||||
# While any sync is ongoing on another process the user will never
|
||||
|
@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
|
|||
# we assume that all the sync requests on that process have stopped.
|
||||
# Stored as a dict from process_id to set of user_id, and a dict of
|
||||
# process_id to millisecond timestamp last updated.
|
||||
self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]]
|
||||
self.external_process_last_updated_ms = {} # type: Dict[str, int]
|
||||
self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
|
||||
self.external_process_last_updated_ms: Dict[str, int] = {}
|
||||
|
||||
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
|
||||
|
||||
|
@ -1581,9 +1581,7 @@ class PresenceEventSource:
|
|||
|
||||
# The set of users that we're interested in and that have had a presence update.
|
||||
# We'll actually pull the presence updates for these users at the end.
|
||||
interested_and_updated_users = (
|
||||
set()
|
||||
) # type: Union[Set[str], FrozenSet[str]]
|
||||
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
|
||||
|
||||
if from_key:
|
||||
# First get all users that have had a presence update
|
||||
|
@ -1950,8 +1948,8 @@ async def get_interested_parties(
|
|||
A 2-tuple of `(room_ids_to_states, users_to_states)`,
|
||||
with each item being a dict of `entity_name` -> `[UserPresenceState]`
|
||||
"""
|
||||
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
|
||||
users_to_states = {} # type: Dict[str, List[UserPresenceState]]
|
||||
room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
|
||||
users_to_states: Dict[str, List[UserPresenceState]] = {}
|
||||
for state in states:
|
||||
room_ids = await store.get_rooms_for_user(state.user_id)
|
||||
for room_id in room_ids:
|
||||
|
@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
|
|||
# stream_id, destinations, user_ids)`. We don't store the full states
|
||||
# for efficiency, and remote workers will already have the full states
|
||||
# cached.
|
||||
self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
|
||||
self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
|
||||
|
||||
self._next_id = 1
|
||||
|
||||
# Map from instance name to current token
|
||||
self._current_tokens = {} # type: Dict[str, int]
|
||||
self._current_tokens: Dict[str, int] = {}
|
||||
|
||||
if self._queue_presence_updates:
|
||||
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
|
||||
|
@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
|
|||
# handle the case where `from_token` stream ID has already been dropped.
|
||||
start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
|
||||
|
||||
to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
|
||||
to_send: List[Tuple[int, Tuple[str, str]]] = []
|
||||
limited = False
|
||||
new_id = upto_token
|
||||
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
|
||||
|
@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
|
|||
if not self._federation:
|
||||
return
|
||||
|
||||
hosts_to_users = {} # type: Dict[str, Set[str]]
|
||||
hosts_to_users: Dict[str, Set[str]] = {}
|
||||
for row in rows:
|
||||
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
|
||||
|
||||
|
|
|
@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
|
|||
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
||||
)
|
||||
|
||||
displayname_to_set = new_displayname # type: Optional[str]
|
||||
displayname_to_set: Optional[str] = new_displayname
|
||||
if new_displayname == "":
|
||||
displayname_to_set = None
|
||||
|
||||
|
@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
|
|||
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
|
||||
)
|
||||
|
||||
avatar_url_to_set = new_avatar_url # type: Optional[str]
|
||||
avatar_url_to_set: Optional[str] = new_avatar_url
|
||||
if new_avatar_url == "":
|
||||
avatar_url_to_set = None
|
||||
|
||||
|
|
|
@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
|
|||
|
||||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
||||
"""Takes a list of receipts, stores them and informs the notifier."""
|
||||
min_batch_id = None # type: Optional[int]
|
||||
max_batch_id = None # type: Optional[int]
|
||||
min_batch_id: Optional[int] = None
|
||||
max_batch_id: Optional[int] = None
|
||||
|
||||
for receipt in receipts:
|
||||
res = await self.store.insert_receipt(
|
||||
|
|
|
@ -55,15 +55,12 @@ login_counter = Counter(
|
|||
["guest", "auth_provider"],
|
||||
)
|
||||
|
||||
LoginDict = TypedDict(
|
||||
"LoginDict",
|
||||
{
|
||||
"device_id": str,
|
||||
"access_token": str,
|
||||
"valid_until_ms": Optional[int],
|
||||
"refresh_token": Optional[str],
|
||||
},
|
||||
)
|
||||
|
||||
class LoginDict(TypedDict):
|
||||
device_id: str
|
||||
access_token: str
|
||||
valid_until_ms: Optional[int]
|
||||
refresh_token: Optional[str]
|
||||
|
||||
|
||||
class RegistrationHandler(BaseHandler):
|
||||
|
@ -77,6 +74,7 @@ class RegistrationHandler(BaseHandler):
|
|||
self.identity_handler = self.hs.get_identity_handler()
|
||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
self._server_name = hs.hostname
|
||||
|
||||
|
@ -700,6 +698,10 @@ class RegistrationHandler(BaseHandler):
|
|||
shadow_banned=shadow_banned,
|
||||
)
|
||||
|
||||
# Only call the account validity module(s) on the main process, to avoid
|
||||
# repeating e.g. database writes on all of the workers.
|
||||
await self._account_validity_handler.on_user_registration(user_id)
|
||||
|
||||
async def register_device(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
|
@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
self.config = hs.config
|
||||
|
||||
# Room state based off defined presets
|
||||
self._presets_dict = {
|
||||
self._presets_dict: Dict[str, Dict[str, Any]] = {
|
||||
RoomCreationPreset.PRIVATE_CHAT: {
|
||||
"join_rules": JoinRules.INVITE,
|
||||
"history_visibility": HistoryVisibility.SHARED,
|
||||
|
@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"guest_can_join": False,
|
||||
"power_level_content_override": {},
|
||||
},
|
||||
} # type: Dict[str, Dict[str, Any]]
|
||||
}
|
||||
|
||||
# Modify presets to selectively enable encryption by default per homeserver config
|
||||
for preset_name, preset_config in self._presets_dict.items():
|
||||
|
@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
|
|||
# If a user tries to update the same room multiple times in quick
|
||||
# succession, only process the first attempt and return its result to
|
||||
# subsequent requests
|
||||
self._upgrade_response_cache = ResponseCache(
|
||||
self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
|
||||
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
|
||||
) # type: ResponseCache[Tuple[str, str]]
|
||||
)
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
|
||||
self.third_party_event_rules = hs.get_third_party_event_rules()
|
||||
|
@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
|
|||
if not await self.spam_checker.user_may_create_room(user_id):
|
||||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
|
||||
creation_content = {
|
||||
creation_content: JsonDict = {
|
||||
"room_version": new_room_version.identifier,
|
||||
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
|
||||
} # type: JsonDict
|
||||
}
|
||||
|
||||
# Check if old room was non-federatable
|
||||
|
||||
|
@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
|
|||
else:
|
||||
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
||||
# Check whether the third party rules allows/changes the room create
|
||||
# request.
|
||||
event_allowed = await self.third_party_event_rules.on_create_room(
|
||||
# Let the third party rules modify the room creation config if needed, or abort
|
||||
# the room creation entirely with an exception.
|
||||
await self.third_party_event_rules.on_create_room(
|
||||
requester, config, is_requester_admin=is_requester_admin
|
||||
)
|
||||
if not event_allowed:
|
||||
raise SynapseError(
|
||||
403, "You are not permitted to create rooms", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
|
||||
user_id
|
||||
|
@ -936,7 +932,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
etype=EventTypes.PowerLevels, content=pl_content
|
||||
)
|
||||
else:
|
||||
power_level_content = {
|
||||
power_level_content: JsonDict = {
|
||||
"users": {creator_id: 100},
|
||||
"users_default": 0,
|
||||
"events": {
|
||||
|
@ -955,7 +951,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"kick": 50,
|
||||
"redact": 50,
|
||||
"invite": 50,
|
||||
} # type: JsonDict
|
||||
}
|
||||
|
||||
if config["original_invitees_have_ops"]:
|
||||
for invitee in invite_list:
|
||||
|
|
|
@ -48,12 +48,12 @@ class RoomListHandler(BaseHandler):
|
|||
super().__init__(hs)
|
||||
self.enable_room_list_search = hs.config.enable_room_list_search
|
||||
|
||||
self.response_cache = ResponseCache(
|
||||
hs.get_clock(), "room_list"
|
||||
) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
|
||||
self.remote_response_cache = ResponseCache(
|
||||
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
|
||||
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
|
||||
self.response_cache: ResponseCache[
|
||||
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
|
||||
] = ResponseCache(hs.get_clock(), "room_list")
|
||||
self.remote_response_cache: ResponseCache[
|
||||
Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
|
||||
] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
|
||||
|
||||
async def get_local_public_room_list(
|
||||
self,
|
||||
|
@ -140,10 +140,10 @@ class RoomListHandler(BaseHandler):
|
|||
if since_token:
|
||||
batch_token = RoomListNextBatch.from_token(since_token)
|
||||
|
||||
bounds = (
|
||||
bounds: Optional[Tuple[int, str]] = (
|
||||
batch_token.last_joined_members,
|
||||
batch_token.last_room_id,
|
||||
) # type: Optional[Tuple[int, str]]
|
||||
)
|
||||
forwards = batch_token.direction_is_forward
|
||||
has_batch_token = True
|
||||
else:
|
||||
|
@ -183,7 +183,7 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
results = [build_room_entry(r) for r in results]
|
||||
|
||||
response = {} # type: JsonDict
|
||||
response: JsonDict = {}
|
||||
num_results = len(results)
|
||||
if limit is not None:
|
||||
more_to_come = num_results == probing_limit
|
||||
|
@ -384,7 +384,11 @@ class RoomListHandler(BaseHandler):
|
|||
):
|
||||
logger.debug("Falling back to locally-filtered /publicRooms")
|
||||
else:
|
||||
raise # Not an error that should trigger a fallback.
|
||||
# Not an error that should trigger a fallback.
|
||||
raise SynapseError(502, "Failed to fetch room list")
|
||||
except RequestSendFailed:
|
||||
# Not an error that should trigger a fallback.
|
||||
raise SynapseError(502, "Failed to fetch room list")
|
||||
|
||||
# if we reach this point, then we fall back to the situation where
|
||||
# we currently don't support searching across federation, so we have
|
||||
|
|
|
@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
|
|||
self.unstable_idp_brand = None
|
||||
|
||||
# a map from saml session id to Saml2SessionData object
|
||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
|
||||
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
self._sso_handler.register_identity_provider(self)
|
||||
|
@ -372,7 +372,7 @@ class SamlHandler(BaseHandler):
|
|||
|
||||
|
||||
DOT_REPLACE_PATTERN = re.compile(
|
||||
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
|
||||
"[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
|
||||
)
|
||||
|
||||
|
||||
|
@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
|
|||
return username
|
||||
|
||||
|
||||
MXID_MAPPER_MAP = {
|
||||
MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
|
||||
"hexencode": map_username_to_mxid_localpart,
|
||||
"dotreplace": dot_replace_for_mxid,
|
||||
} # type: Dict[str, Callable[[str], str]]
|
||||
}
|
||||
|
||||
|
||||
@attr.s
|
||||
|
|
|
@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
|
|||
# If doing a subset of all rooms seearch, check if any of the rooms
|
||||
# are from an upgraded room, and search their contents as well
|
||||
if search_filter.rooms:
|
||||
historical_room_ids = [] # type: List[str]
|
||||
historical_room_ids: List[str] = []
|
||||
for room_id in search_filter.rooms:
|
||||
# Add any previous rooms to the search if they exist
|
||||
ids = await self.get_old_rooms_from_upgraded_room(room_id)
|
||||
|
@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
|
|||
rank_map = {} # event_id -> rank of event
|
||||
allowed_events = []
|
||||
# Holds result of grouping by room, if applicable
|
||||
room_groups = {} # type: Dict[str, JsonDict]
|
||||
room_groups: Dict[str, JsonDict] = {}
|
||||
# Holds result of grouping by sender, if applicable
|
||||
sender_group = {} # type: Dict[str, JsonDict]
|
||||
sender_group: Dict[str, JsonDict] = {}
|
||||
|
||||
# Holds the next_batch for the entire result set if one of those exists
|
||||
global_next_batch = None
|
||||
|
@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
|
|||
s["results"].append(e.event_id)
|
||||
|
||||
elif order_by == "recent":
|
||||
room_events = [] # type: List[EventBase]
|
||||
room_events: List[EventBase] = []
|
||||
i = 0
|
||||
|
||||
pagination_token = batch_token
|
||||
|
|
|
@ -90,14 +90,14 @@ class SpaceSummaryHandler:
|
|||
room_queue = deque((_RoomQueueEntry(room_id, ()),))
|
||||
|
||||
# rooms we have already processed
|
||||
processed_rooms = set() # type: Set[str]
|
||||
processed_rooms: Set[str] = set()
|
||||
|
||||
# events we have already processed. We don't necessarily have their event ids,
|
||||
# so instead we key on (room id, state key)
|
||||
processed_events = set() # type: Set[Tuple[str, str]]
|
||||
processed_events: Set[Tuple[str, str]] = set()
|
||||
|
||||
rooms_result = [] # type: List[JsonDict]
|
||||
events_result = [] # type: List[JsonDict]
|
||||
rooms_result: List[JsonDict] = []
|
||||
events_result: List[JsonDict] = []
|
||||
|
||||
while room_queue and len(rooms_result) < MAX_ROOMS:
|
||||
queue_entry = room_queue.popleft()
|
||||
|
@ -272,10 +272,10 @@ class SpaceSummaryHandler:
|
|||
# the set of rooms that we should not walk further. Initialise it with the
|
||||
# excluded-rooms list; we will add other rooms as we process them so that
|
||||
# we do not loop.
|
||||
processed_rooms = set(exclude_rooms) # type: Set[str]
|
||||
processed_rooms: Set[str] = set(exclude_rooms)
|
||||
|
||||
rooms_result = [] # type: List[JsonDict]
|
||||
events_result = [] # type: List[JsonDict]
|
||||
rooms_result: List[JsonDict] = []
|
||||
events_result: List[JsonDict] = []
|
||||
|
||||
while room_queue and len(rooms_result) < MAX_ROOMS:
|
||||
room_id = room_queue.popleft()
|
||||
|
@ -353,7 +353,7 @@ class SpaceSummaryHandler:
|
|||
max_children = MAX_ROOMS_PER_SPACE
|
||||
|
||||
now = self._clock.time_msec()
|
||||
events_result = [] # type: List[JsonDict]
|
||||
events_result: List[JsonDict] = []
|
||||
for edge_event in itertools.islice(child_events, max_children):
|
||||
events_result.append(
|
||||
await self._event_serializer.serialize_event(
|
||||
|
|
|
@ -202,10 +202,10 @@ class SsoHandler:
|
|||
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
|
||||
|
||||
# a map from session id to session data
|
||||
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
|
||||
self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
|
||||
|
||||
# map from idp_id to SsoIdentityProvider
|
||||
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
|
||||
self._identity_providers: Dict[str, SsoIdentityProvider] = {}
|
||||
|
||||
self._consent_at_registration = hs.config.consent.user_consent_at_registration
|
||||
|
||||
|
@ -296,7 +296,7 @@ class SsoHandler:
|
|||
)
|
||||
|
||||
# if the client chose an IdP, use that
|
||||
idp = None # type: Optional[SsoIdentityProvider]
|
||||
idp: Optional[SsoIdentityProvider] = None
|
||||
if idp_id:
|
||||
idp = self._identity_providers.get(idp_id)
|
||||
if not idp:
|
||||
|
@ -669,9 +669,9 @@ class SsoHandler:
|
|||
remote_user_id,
|
||||
)
|
||||
|
||||
user_id_to_verify = await self._auth_handler.get_session_data(
|
||||
user_id_to_verify: str = await self._auth_handler.get_session_data(
|
||||
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
||||
) # type: str
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
|
@ -793,7 +793,7 @@ class SsoHandler:
|
|||
session.use_display_name = use_display_name
|
||||
|
||||
emails_from_idp = set(session.emails)
|
||||
filtered_emails = set() # type: Set[str]
|
||||
filtered_emails: Set[str] = set()
|
||||
|
||||
# we iterate through the list rather than just building a set conjunction, so
|
||||
# that we can log attempts to use unknown addresses
|
||||
|
|
|
@ -49,7 +49,7 @@ class StatsHandler:
|
|||
self.stats_enabled = hs.config.stats_enabled
|
||||
|
||||
# The current position in the current_state_delta stream
|
||||
self.pos = None # type: Optional[int]
|
||||
self.pos: Optional[int] = None
|
||||
|
||||
# Guard to ensure we only process deltas one at a time
|
||||
self._is_processing = False
|
||||
|
@ -131,10 +131,10 @@ class StatsHandler:
|
|||
mapping from room/user ID to changes in the various fields.
|
||||
"""
|
||||
|
||||
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
||||
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
||||
room_to_stats_deltas: Dict[str, CounterType[str]] = {}
|
||||
user_to_stats_deltas: Dict[str, CounterType[str]] = {}
|
||||
|
||||
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
|
||||
room_to_state_updates: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for delta in deltas:
|
||||
typ = delta["type"]
|
||||
|
@ -164,7 +164,7 @@ class StatsHandler:
|
|||
)
|
||||
continue
|
||||
|
||||
event_content = {} # type: JsonDict
|
||||
event_content: JsonDict = {}
|
||||
|
||||
if event_id is not None:
|
||||
event = await self.store.get_event(event_id, allow_none=True)
|
||||
|
|
|
@ -279,12 +279,14 @@ class SyncHandler:
|
|||
self.state_store = self.storage.state
|
||||
|
||||
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
|
||||
self.lazy_loaded_members_cache = ExpiringCache(
|
||||
self.lazy_loaded_members_cache: ExpiringCache[
|
||||
Tuple[str, Optional[str]], LruCache[str, str]
|
||||
] = ExpiringCache(
|
||||
"lazy_loaded_members_cache",
|
||||
self.clock,
|
||||
max_len=0,
|
||||
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
||||
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
|
||||
)
|
||||
|
||||
async def wait_for_sync_for_user(
|
||||
self,
|
||||
|
@ -441,7 +443,7 @@ class SyncHandler:
|
|||
)
|
||||
now_token = now_token.copy_and_replace("typing_key", typing_key)
|
||||
|
||||
ephemeral_by_room = {} # type: JsonDict
|
||||
ephemeral_by_room: JsonDict = {}
|
||||
|
||||
for event in typing:
|
||||
# we want to exclude the room_id from the event, but modifying the
|
||||
|
@ -503,7 +505,7 @@ class SyncHandler:
|
|||
# We check if there are any state events, if there are then we pass
|
||||
# all current state events to the filter_events function. This is to
|
||||
# ensure that we always include current state in the timeline
|
||||
current_state_ids = frozenset() # type: FrozenSet[str]
|
||||
current_state_ids: FrozenSet[str] = frozenset()
|
||||
if any(e.is_state() for e in recents):
|
||||
current_state_ids_map = await self.store.get_current_state_ids(
|
||||
room_id
|
||||
|
@ -784,9 +786,9 @@ class SyncHandler:
|
|||
def get_lazy_loaded_members_cache(
|
||||
self, cache_key: Tuple[str, Optional[str]]
|
||||
) -> LruCache[str, str]:
|
||||
cache = self.lazy_loaded_members_cache.get(
|
||||
cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
|
||||
cache_key
|
||||
) # type: Optional[LruCache[str, str]]
|
||||
)
|
||||
if cache is None:
|
||||
logger.debug("creating LruCache for %r", cache_key)
|
||||
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|
||||
|
@ -985,7 +987,7 @@ class SyncHandler:
|
|||
if t[0] == EventTypes.Member:
|
||||
cache.set(t[1], event_id)
|
||||
|
||||
state = {} # type: Dict[str, EventBase]
|
||||
state: Dict[str, EventBase] = {}
|
||||
if state_ids:
|
||||
state = await self.store.get_events(list(state_ids.values()))
|
||||
|
||||
|
@ -1089,8 +1091,8 @@ class SyncHandler:
|
|||
|
||||
logger.debug("Fetching OTK data")
|
||||
device_id = sync_config.device_id
|
||||
one_time_key_counts = {} # type: JsonDict
|
||||
unused_fallback_key_types = [] # type: List[str]
|
||||
one_time_key_counts: JsonDict = {}
|
||||
unused_fallback_key_types: List[str] = []
|
||||
if device_id:
|
||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||
user_id, device_id
|
||||
|
@ -1438,7 +1440,7 @@ class SyncHandler:
|
|||
)
|
||||
|
||||
if block_all_room_ephemeral:
|
||||
ephemeral_by_room = {} # type: Dict[str, List[JsonDict]]
|
||||
ephemeral_by_room: Dict[str, List[JsonDict]] = {}
|
||||
else:
|
||||
now_token, ephemeral_by_room = await self.ephemeral_by_room(
|
||||
sync_result_builder,
|
||||
|
@ -1469,7 +1471,7 @@ class SyncHandler:
|
|||
|
||||
# If there is ignored users account data and it matches the proper type,
|
||||
# then use it.
|
||||
ignored_users = frozenset() # type: FrozenSet[str]
|
||||
ignored_users: FrozenSet[str] = frozenset()
|
||||
if ignored_account_data:
|
||||
ignored_users_data = ignored_account_data.get("ignored_users", {})
|
||||
if isinstance(ignored_users_data, dict):
|
||||
|
@ -1587,7 +1589,7 @@ class SyncHandler:
|
|||
user_id, since_token.room_key, now_token.room_key
|
||||
)
|
||||
|
||||
mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]]
|
||||
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
|
||||
for event in rooms_changed:
|
||||
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
|
||||
|
||||
|
@ -1600,7 +1602,7 @@ class SyncHandler:
|
|||
logger.debug(
|
||||
"Membership changes in %s: [%s]",
|
||||
room_id,
|
||||
", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
|
||||
", ".join("%s (%s)" % (e.event_id, e.membership) for e in events),
|
||||
)
|
||||
|
||||
non_joins = [e for e in events if e.membership != Membership.JOIN]
|
||||
|
@ -1723,7 +1725,7 @@ class SyncHandler:
|
|||
# This is all screaming out for a refactor, as the logic here is
|
||||
# subtle and the moving parts numerous.
|
||||
if leave_event.internal_metadata.is_out_of_band_membership():
|
||||
batch_events = [leave_event] # type: Optional[List[EventBase]]
|
||||
batch_events: Optional[List[EventBase]] = [leave_event]
|
||||
else:
|
||||
batch_events = None
|
||||
|
||||
|
@ -1972,7 +1974,7 @@ class SyncHandler:
|
|||
room_id, batch, sync_config, since_token, now_token, full_state=full_state
|
||||
)
|
||||
|
||||
summary = {} # type: Optional[JsonDict]
|
||||
summary: Optional[JsonDict] = {}
|
||||
|
||||
# we include a summary in room responses when we're lazy loading
|
||||
# members (as the client otherwise doesn't have enough info to form
|
||||
|
@ -1996,7 +1998,7 @@ class SyncHandler:
|
|||
)
|
||||
|
||||
if room_builder.rtype == "joined":
|
||||
unread_notifications = {} # type: Dict[str, int]
|
||||
unread_notifications: Dict[str, int] = {}
|
||||
room_sync = JoinedSyncResult(
|
||||
room_id=room_id,
|
||||
timeline=batch,
|
||||
|
|
|
@ -68,11 +68,11 @@ class FollowerTypingHandler:
|
|||
)
|
||||
|
||||
# map room IDs to serial numbers
|
||||
self._room_serials = {} # type: Dict[str, int]
|
||||
self._room_serials: Dict[str, int] = {}
|
||||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {} # type: Dict[str, Set[str]]
|
||||
self._room_typing: Dict[str, Set[str]] = {}
|
||||
|
||||
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
|
||||
self._member_last_federation_poke: Dict[RoomMember, int] = {}
|
||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||
self._latest_room_serial = 0
|
||||
|
||||
|
@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
# clock time we expect to stop
|
||||
self._member_typing_until = {} # type: Dict[RoomMember, int]
|
||||
self._member_typing_until: Dict[RoomMember, int] = {}
|
||||
|
||||
# caches which room_ids changed at which serials
|
||||
self._typing_stream_change_cache = StreamChangeCache(
|
||||
|
@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
|
||||
last_id
|
||||
) # type: Optional[Iterable[str]]
|
||||
changed_rooms: Optional[
|
||||
Iterable[str]
|
||||
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
|
||||
|
||||
if changed_rooms is None:
|
||||
changed_rooms = self._room_serials
|
||||
|
|
|
@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
self.search_all_users = hs.config.user_directory_search_all_users
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
# The current position in the current_state_delta stream
|
||||
self.pos = None # type: Optional[int]
|
||||
self.pos: Optional[int] = None
|
||||
|
||||
# Guard to ensure we only process deltas one at a time
|
||||
self._is_processing = False
|
||||
|
|
|
@ -172,7 +172,7 @@ class ProxyAgent(_AgentBase):
|
|||
"""
|
||||
uri = uri.strip()
|
||||
if not _VALID_URI.match(uri):
|
||||
raise ValueError("Invalid URI {!r}".format(uri))
|
||||
raise ValueError(f"Invalid URI {uri!r}")
|
||||
|
||||
parsed_uri = URI.fromBytes(uri)
|
||||
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
||||
|
|
|
@ -384,7 +384,7 @@ class SynapseRequest(Request):
|
|||
# authenticated (e.g. and admin is puppetting a user) then we log both.
|
||||
requester, authenticated_entity = self.get_authenticated_entity()
|
||||
if authenticated_entity:
|
||||
requester = "{}.{}".format(authenticated_entity, requester)
|
||||
requester = f"{authenticated_entity}.{requester}"
|
||||
|
||||
self.site.access_logger.log(
|
||||
log_level,
|
||||
|
|
|
@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"):
|
|||
|
||||
config = JaegerConfig(
|
||||
config=hs.config.jaeger_config,
|
||||
service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
|
||||
service_name=f"{hs.config.server_name} {hs.get_instance_name()}",
|
||||
scope_manager=LogContextScopeManager(hs.config),
|
||||
metrics_factory=PrometheusMetricsFactory(),
|
||||
)
|
||||
|
|
|
@ -34,7 +34,7 @@ from twisted.web.resource import Resource
|
|||
|
||||
from synapse.util import caches
|
||||
|
||||
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
|
||||
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
|
||||
|
||||
|
||||
INF = float("inf")
|
||||
|
@ -55,8 +55,8 @@ def floatToGoString(d):
|
|||
# Go switches to exponents sooner than Python.
|
||||
# We only need to care about positive values for le/quantile.
|
||||
if d > 0 and dot > 6:
|
||||
mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.")
|
||||
return "{0}e+0{1}".format(mantissa, dot - 1)
|
||||
mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.")
|
||||
return f"{mantissa}e+0{dot - 1}"
|
||||
return s
|
||||
|
||||
|
||||
|
@ -65,7 +65,7 @@ def sample_line(line, name):
|
|||
labelstr = "{{{0}}}".format(
|
||||
",".join(
|
||||
[
|
||||
'{0}="{1}"'.format(
|
||||
'{}="{}"'.format(
|
||||
k,
|
||||
v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
|
||||
)
|
||||
|
@ -78,10 +78,8 @@ def sample_line(line, name):
|
|||
timestamp = ""
|
||||
if line.timestamp is not None:
|
||||
# Convert to milliseconds.
|
||||
timestamp = " {0:d}".format(int(float(line.timestamp) * 1000))
|
||||
return "{0}{1} {2}{3}\n".format(
|
||||
name, labelstr, floatToGoString(line.value), timestamp
|
||||
)
|
||||
timestamp = f" {int(float(line.timestamp) * 1000):d}"
|
||||
return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
|
||||
|
||||
|
||||
def generate_latest(registry, emit_help=False):
|
||||
|
@ -118,12 +116,12 @@ def generate_latest(registry, emit_help=False):
|
|||
# Output in the old format for compatibility.
|
||||
if emit_help:
|
||||
output.append(
|
||||
"# HELP {0} {1}\n".format(
|
||||
"# HELP {} {}\n".format(
|
||||
mname,
|
||||
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
|
||||
)
|
||||
)
|
||||
output.append("# TYPE {0} {1}\n".format(mname, mtype))
|
||||
output.append(f"# TYPE {mname} {mtype}\n")
|
||||
|
||||
om_samples: Dict[str, List[str]] = {}
|
||||
for s in metric.samples:
|
||||
|
@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False):
|
|||
for suffix, lines in sorted(om_samples.items()):
|
||||
if emit_help:
|
||||
output.append(
|
||||
"# HELP {0}{1} {2}\n".format(
|
||||
"# HELP {}{} {}\n".format(
|
||||
metric.name,
|
||||
suffix,
|
||||
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
|
||||
)
|
||||
)
|
||||
output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix))
|
||||
output.append(f"# TYPE {metric.name}{suffix} gauge\n")
|
||||
output.extend(lines)
|
||||
|
||||
# Get rid of the weird colon things while we're at it
|
||||
|
@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False):
|
|||
# Also output in the new format, if it's different.
|
||||
if emit_help:
|
||||
output.append(
|
||||
"# HELP {0} {1}\n".format(
|
||||
"# HELP {} {}\n".format(
|
||||
mnewname,
|
||||
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
|
||||
)
|
||||
)
|
||||
output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
|
||||
output.append(f"# TYPE {mnewname} {mtype}\n")
|
||||
|
||||
for s in metric.samples:
|
||||
# Get rid of the OpenMetrics specific samples (we should already have
|
||||
|
|
|
@ -137,8 +137,7 @@ class _Collector:
|
|||
_background_process_db_txn_duration,
|
||||
_background_process_db_sched_duration,
|
||||
):
|
||||
for r in m.collect():
|
||||
yield r
|
||||
yield from m.collect()
|
||||
|
||||
|
||||
REGISTRY.register(_Collector())
|
||||
|
|
|
@ -12,18 +12,42 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import email.utils
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import jinja2
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import IResource
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.server import (
|
||||
DirectServeHtmlResource,
|
||||
DirectServeJsonResource,
|
||||
respond_with_html,
|
||||
)
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi
|
|||
are loaded into Synapse.
|
||||
"""
|
||||
|
||||
__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
|
||||
__all__ = [
|
||||
"errors",
|
||||
"make_deferred_yieldable",
|
||||
"parse_json_object_from_request",
|
||||
"respond_with_html",
|
||||
"run_in_background",
|
||||
"cached",
|
||||
"UserID",
|
||||
"DatabasePool",
|
||||
"LoggingTransaction",
|
||||
"DirectServeHtmlResource",
|
||||
"DirectServeJsonResource",
|
||||
"ModuleApi",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,12 +89,28 @@ class ModuleApi:
|
|||
self._server_name = hs.hostname
|
||||
self._presence_stream = hs.get_event_sources().sources["presence"]
|
||||
self._state = hs.get_state_handler()
|
||||
self._clock: Clock = hs.get_clock()
|
||||
self._send_email_handler = hs.get_send_email_handler()
|
||||
|
||||
try:
|
||||
app_name = self._hs.config.email_app_name
|
||||
|
||||
self._from_string = self._hs.config.email_notif_from % {"app": app_name}
|
||||
except (KeyError, TypeError):
|
||||
# If substitution failed (which can happen if the string contains
|
||||
# placeholders other than just "app", or if the type of the placeholder is
|
||||
# not a string), fall back to the bare strings.
|
||||
self._from_string = self._hs.config.email_notif_from
|
||||
|
||||
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
||||
|
||||
# We expose these as properties below in order to attach a helpful docstring.
|
||||
self._http_client: SimpleHttpClient = hs.get_simple_http_client()
|
||||
self._public_room_list_manager = PublicRoomListManager(hs)
|
||||
|
||||
self._spam_checker = hs.get_spam_checker()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
self._third_party_event_rules = hs.get_third_party_event_rules()
|
||||
|
||||
#################################################################################
|
||||
# The following methods should only be called during the module's initialisation.
|
||||
|
@ -67,6 +120,16 @@ class ModuleApi:
|
|||
"""Registers callbacks for spam checking capabilities."""
|
||||
return self._spam_checker.register_callbacks
|
||||
|
||||
@property
|
||||
def register_account_validity_callbacks(self):
|
||||
"""Registers callbacks for account validity capabilities."""
|
||||
return self._account_validity_handler.register_account_validity_callbacks
|
||||
|
||||
@property
|
||||
def register_third_party_rules_callbacks(self):
|
||||
"""Registers callbacks for third party event rules capabilities."""
|
||||
return self._third_party_event_rules.register_third_party_rules_callbacks
|
||||
|
||||
def register_web_resource(self, path: str, resource: IResource):
|
||||
"""Registers a web resource to be served at the given path.
|
||||
|
||||
|
@ -101,22 +164,56 @@ class ModuleApi:
|
|||
"""
|
||||
return self._public_room_list_manager
|
||||
|
||||
def get_user_by_req(self, req, allow_guest=False):
|
||||
@property
|
||||
def public_baseurl(self) -> str:
|
||||
"""The configured public base URL for this homeserver."""
|
||||
return self._hs.config.public_baseurl
|
||||
|
||||
@property
|
||||
def email_app_name(self) -> str:
|
||||
"""The application name configured in the homeserver's configuration."""
|
||||
return self._hs.config.email.email_app_name
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
req: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Check the access_token provided for a request
|
||||
|
||||
Args:
|
||||
req (twisted.web.server.Request): Incoming HTTP request
|
||||
allow_guest (bool): True if guest users should be allowed. If this
|
||||
req: Incoming HTTP request
|
||||
allow_guest: True if guest users should be allowed. If this
|
||||
is False, and the access token is for a guest user, an
|
||||
AuthError will be thrown
|
||||
allow_expired: True if expired users should be allowed. If this
|
||||
is False, and the access token is for an expired user, an
|
||||
AuthError will be thrown
|
||||
|
||||
Returns:
|
||||
twisted.internet.defer.Deferred[synapse.types.Requester]:
|
||||
the requester for this request
|
||||
The requester for this request
|
||||
|
||||
Raises:
|
||||
synapse.api.errors.AuthError: if no user by that token exists,
|
||||
InvalidClientCredentialsError: if no user by that token exists,
|
||||
or the token is invalid.
|
||||
"""
|
||||
return self._auth.get_user_by_req(req, allow_guest)
|
||||
return await self._auth.get_user_by_req(
|
||||
req,
|
||||
allow_guest,
|
||||
allow_expired=allow_expired,
|
||||
)
|
||||
|
||||
async def is_user_admin(self, user_id: str) -> bool:
|
||||
"""Checks if a user is a server admin.
|
||||
|
||||
Args:
|
||||
user_id: The Matrix ID of the user to check.
|
||||
|
||||
Returns:
|
||||
True if the user is a server admin, False otherwise.
|
||||
"""
|
||||
return await self._store.is_server_admin(UserID.from_string(user_id))
|
||||
|
||||
def get_qualified_user_id(self, username):
|
||||
"""Qualify a user id, if necessary
|
||||
|
@ -134,6 +231,32 @@ class ModuleApi:
|
|||
return username
|
||||
return UserID(username, self._hs.hostname).to_string()
|
||||
|
||||
async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
|
||||
"""Look up the profile info for the user with the given localpart.
|
||||
|
||||
Args:
|
||||
localpart: The localpart to look up profile information for.
|
||||
|
||||
Returns:
|
||||
The profile information (i.e. display name and avatar URL).
|
||||
"""
|
||||
return await self._store.get_profileinfo(localpart)
|
||||
|
||||
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
|
||||
"""Look up the threepids (email addresses and phone numbers) associated with the
|
||||
given Matrix user ID.
|
||||
|
||||
Args:
|
||||
user_id: The Matrix user ID to look up threepids for.
|
||||
|
||||
Returns:
|
||||
A list of threepids, each threepid being represented by a dictionary
|
||||
containing a "medium" key which value is "email" for email addresses and
|
||||
"msisdn" for phone numbers, and an "address" key which value is the
|
||||
threepid's address.
|
||||
"""
|
||||
return await self._store.user_get_threepids(user_id)
|
||||
|
||||
def check_user_exists(self, user_id):
|
||||
"""Check if user exists.
|
||||
|
||||
|
@ -464,6 +587,88 @@ class ModuleApi:
|
|||
presence_events, destination
|
||||
)
|
||||
|
||||
def looping_background_call(
|
||||
self,
|
||||
f: Callable,
|
||||
msec: float,
|
||||
*args,
|
||||
desc: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Wraps a function as a background process and calls it repeatedly.
|
||||
|
||||
Waits `msec` initially before calling `f` for the first time.
|
||||
|
||||
Args:
|
||||
f: The function to call repeatedly. f can be either synchronous or
|
||||
asynchronous, and must follow Synapse's logcontext rules.
|
||||
More info about logcontexts is available at
|
||||
https://matrix-org.github.io/synapse/latest/log_contexts.html
|
||||
msec: How long to wait between calls in milliseconds.
|
||||
*args: Positional arguments to pass to function.
|
||||
desc: The background task's description. Default to the function's name.
|
||||
**kwargs: Key arguments to pass to function.
|
||||
"""
|
||||
if desc is None:
|
||||
desc = f.__name__
|
||||
|
||||
if self._hs.config.run_background_tasks:
|
||||
self._clock.looping_call(
|
||||
run_as_background_process,
|
||||
msec,
|
||||
desc,
|
||||
f,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Not running looping call %s as the configuration forbids it",
|
||||
f,
|
||||
)
|
||||
|
||||
async def send_mail(
|
||||
self,
|
||||
recipient: str,
|
||||
subject: str,
|
||||
html: str,
|
||||
text: str,
|
||||
):
|
||||
"""Send an email on behalf of the homeserver.
|
||||
|
||||
Args:
|
||||
recipient: The email address for the recipient.
|
||||
subject: The email's subject.
|
||||
html: The email's HTML content.
|
||||
text: The email's text content.
|
||||
"""
|
||||
await self._send_email_handler.send_email(
|
||||
email_address=recipient,
|
||||
subject=subject,
|
||||
app_name=self.email_app_name,
|
||||
html=html,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def read_templates(
|
||||
self,
|
||||
filenames: List[str],
|
||||
custom_template_directory: Optional[str] = None,
|
||||
) -> List[jinja2.Template]:
|
||||
"""Read and load the content of the template files at the given location.
|
||||
By default, Synapse will look for these templates in its configured template
|
||||
directory, but another directory to search in can be provided.
|
||||
|
||||
Args:
|
||||
filenames: The name of the template files to look for.
|
||||
custom_template_directory: An additional directory to look for the files in.
|
||||
|
||||
Returns:
|
||||
A list containing the loaded templates, with the orders matching the one of
|
||||
the filenames parameter.
|
||||
"""
|
||||
return self._hs.config.read_templates(filenames, custom_template_directory)
|
||||
|
||||
|
||||
class PublicRoomListManager:
|
||||
"""Contains methods for adding to, removing from and querying whether a room
|
||||
|
|
|
@ -14,5 +14,9 @@
|
|||
|
||||
"""Exception types which are exposed as part of the stable module API"""
|
||||
|
||||
from synapse.api.errors import RedirectException, SynapseError # noqa: F401
|
||||
from synapse.api.errors import ( # noqa: F401
|
||||
InvalidClientCredentialsError,
|
||||
RedirectException,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.config._base import ConfigError # noqa: F401
|
||||
|
|
|
@ -62,10 +62,6 @@ class PusherPool:
|
|||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._account_validity_enabled = (
|
||||
hs.config.account_validity.account_validity_enabled
|
||||
)
|
||||
|
||||
# We shard the handling of push notifications by user ID.
|
||||
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
@ -89,6 +85,8 @@ class PusherPool:
|
|||
# map from user id to app_id:pushkey to pusher
|
||||
self.pushers: Dict[str, Dict[str, Pusher]] = {}
|
||||
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Starts the pushers off in a background process."""
|
||||
if not self._should_start_pushers:
|
||||
|
@ -238,12 +236,9 @@ class PusherPool:
|
|||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
if self._account_validity_enabled:
|
||||
expired = await self.store.is_account_expired(
|
||||
u, self.clock.time_msec()
|
||||
)
|
||||
if expired:
|
||||
continue
|
||||
expired = await self._account_validity_handler.is_user_expired(u)
|
||||
if expired:
|
||||
continue
|
||||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
|
@ -268,12 +263,9 @@ class PusherPool:
|
|||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
if self._account_validity_enabled:
|
||||
expired = await self.store.is_account_expired(
|
||||
u, self.clock.time_msec()
|
||||
)
|
||||
if expired:
|
||||
continue
|
||||
expired = await self._account_validity_handler.is_user_expired(u)
|
||||
if expired:
|
||||
continue
|
||||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
|
|
|
@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
|||
|
||||
# Get the room ID from the identifier.
|
||||
try:
|
||||
remote_room_hosts = [
|
||||
remote_room_hosts: Optional[List[str]] = [
|
||||
x.decode("ascii") for x in request.args[b"server_name"]
|
||||
] # type: Optional[List[str]]
|
||||
]
|
||||
except Exception:
|
||||
remote_room_hosts = None
|
||||
room_id, remote_room_hosts = await self.resolve_room_id(
|
||||
|
@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
|
|||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter = Filter(
|
||||
json_decoder.decode(filter_json)
|
||||
) # type: Optional[Filter]
|
||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
|
|
|
@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.reactor = hs.get_reactor()
|
||||
self.nonces = {} # type: Dict[str, int]
|
||||
self.nonces: Dict[str, int] = {}
|
||||
self.hs = hs
|
||||
|
||||
def _clear_old_nonces(self):
|
||||
|
@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
if self.account_activity_handler.on_legacy_admin_request_callback:
|
||||
expiration_ts = await (
|
||||
self.account_activity_handler.on_legacy_admin_request_callback(request)
|
||||
)
|
||||
else:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if "user_id" not in body:
|
||||
raise SynapseError(400, "Missing property 'user_id' in the request body")
|
||||
if "user_id" not in body:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing property 'user_id' in the request body",
|
||||
)
|
||||
|
||||
expiration_ts = await self.account_activity_handler.renew_account_for_user(
|
||||
body["user_id"],
|
||||
body.get("expiration_ts"),
|
||||
not body.get("enable_renewal_emails", True),
|
||||
)
|
||||
expiration_ts = await self.account_activity_handler.renew_account_for_user(
|
||||
body["user_id"],
|
||||
body.get("expiration_ts"),
|
||||
not body.get("enable_renewal_emails", True),
|
||||
)
|
||||
|
||||
res = {"expiration_ts": expiration_ts}
|
||||
return 200, res
|
||||
|
|
|
@ -44,19 +44,14 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LoginResponse = TypedDict(
|
||||
"LoginResponse",
|
||||
{
|
||||
"user_id": str,
|
||||
"access_token": str,
|
||||
"home_server": str,
|
||||
"expires_in_ms": Optional[int],
|
||||
"refresh_token": Optional[str],
|
||||
"device_id": str,
|
||||
"well_known": Optional[Dict[str, Any]],
|
||||
},
|
||||
total=False,
|
||||
)
|
||||
class LoginResponse(TypedDict, total=False):
|
||||
user_id: str
|
||||
access_token: str
|
||||
home_server: str
|
||||
expires_in_ms: Optional[int]
|
||||
refresh_token: Optional[str]
|
||||
device_id: str
|
||||
well_known: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
|
@ -121,7 +116,7 @@ class LoginRestServlet(RestServlet):
|
|||
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
||||
|
||||
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
||||
sso_flow = {
|
||||
sso_flow: JsonDict = {
|
||||
"type": LoginRestServlet.SSO_TYPE,
|
||||
"identity_providers": [
|
||||
_get_auth_flow_dict_for_idp(
|
||||
|
@ -129,7 +124,7 @@ class LoginRestServlet(RestServlet):
|
|||
)
|
||||
for idp in self._sso_handler.get_identity_providers().values()
|
||||
],
|
||||
} # type: JsonDict
|
||||
}
|
||||
|
||||
if self._msc2858_enabled:
|
||||
# backwards-compatibility support for clients which don't
|
||||
|
@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet):
|
|||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
|
||||
flows.extend(
|
||||
({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||
)
|
||||
flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||
|
||||
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
|
||||
|
||||
|
@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp(
|
|||
use_unstable_brands: whether we should use brand identifiers suitable
|
||||
for the unstable API
|
||||
"""
|
||||
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
|
||||
e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
|
||||
if idp.idp_icon:
|
||||
e["icon"] = idp.idp_icon
|
||||
if idp.idp_brand:
|
||||
|
@ -561,7 +554,7 @@ class SsoRedirectServlet(RestServlet):
|
|||
finish_request(request)
|
||||
return
|
||||
|
||||
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
|
||||
args: Dict[bytes, List[bytes]] = request.args # type: ignore
|
||||
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
|
||||
sso_url = await self._sso_handler.handle_redirect_request(
|
||||
request,
|
||||
|
|
|
@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
server = parse_string(request, "server", default=None)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
limit = int(content.get("limit", 100)) # type: Optional[int]
|
||||
limit: Optional[int] = int(content.get("limit", 100))
|
||||
since_token = content.get("since", None)
|
||||
search_filter = content.get("filter", None)
|
||||
|
||||
|
@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter = Filter(
|
||||
json_decoder.decode(filter_json)
|
||||
) # type: Optional[Filter]
|
||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||
if (
|
||||
event_filter
|
||||
and event_filter.filter_json.get("event_format", "client")
|
||||
|
@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
|
|||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter = Filter(
|
||||
json_decoder.decode(filter_json)
|
||||
) # type: Optional[Filter]
|
||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
|
@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet):
|
|||
)
|
||||
|
||||
async def on_POST(self, request):
|
||||
if not self.account_validity_renew_by_email_enabled:
|
||||
raise AuthError(
|
||||
403, "Account renewal via email is disabled on this server."
|
||||
)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
||||
user_id = requester.user.to_string()
|
||||
await self.account_activity_handler.send_renewal_email_to_user(user_id)
|
||||
|
|
|
@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||
requester, message_type, content["messages"]
|
||||
)
|
||||
|
||||
response = (200, {}) # type: Tuple[int, dict]
|
||||
response: Tuple[int, dict] = (200, {})
|
||||
return response
|
||||
|
||||
|
||||
|
|
|
@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
|
|||
has_consented = False
|
||||
public_version = username == ""
|
||||
if not public_version:
|
||||
args = request.args # type: Dict[bytes, List[bytes]]
|
||||
args: Dict[bytes, List[bytes]] = request.args
|
||||
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
|
||||
|
||||
self._check_hash(username, userhmac_bytes)
|
||||
|
@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
|
|||
"""
|
||||
version = parse_string(request, "v", required=True)
|
||||
username = parse_string(request, "u", required=True)
|
||||
args = request.args # type: Dict[bytes, List[bytes]]
|
||||
args: Dict[bytes, List[bytes]] = request.args
|
||||
userhmac = parse_bytes_from_args(args, "h", required=True)
|
||||
|
||||
self._check_hash(username, userhmac)
|
||||
|
|
|
@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
|
|||
async def _async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
(server,) = request.postpath
|
||||
query = {server.decode("ascii"): {}} # type: dict
|
||||
query: dict = {server.decode("ascii"): {}}
|
||||
elif len(request.postpath) == 2:
|
||||
server, key_id = request.postpath
|
||||
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
|
||||
|
@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
|
|||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
# Note that the value is unused.
|
||||
cache_misses = {} # type: Dict[str, Dict[str, int]]
|
||||
cache_misses: Dict[str, Dict[str, int]] = {}
|
||||
for (server_name, key_id, _), results in cached.items():
|
||||
results = [(result["ts_added_ms"], result) for result in results]
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import PIL.Image
|
|||
# check for JPEG support.
|
||||
try:
|
||||
PIL.Image._getdecoder("rgb", "jpeg", None)
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
if str(e).startswith("decoder jpeg not available"):
|
||||
raise Exception(
|
||||
"FATAL: jpeg codec not supported. Install pillow correctly! "
|
||||
|
@ -32,7 +32,7 @@ except Exception:
|
|||
# check for PNG support.
|
||||
try:
|
||||
PIL.Image._getdecoder("rgb", "zip", None)
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
if str(e).startswith("decoder zip not available"):
|
||||
raise Exception(
|
||||
"FATAL: zip codec not supported. Install pillow correctly! "
|
||||
|
|
|
@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
|
|||
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
|
||||
try:
|
||||
# The type on postpath seems incorrect in Twisted 21.2.0.
|
||||
postpath = request.postpath # type: List[bytes] # type: ignore
|
||||
postpath: List[bytes] = request.postpath # type: ignore
|
||||
assert postpath
|
||||
|
||||
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||
|
|
|
@ -78,16 +78,16 @@ class MediaRepository:
|
|||
|
||||
Thumbnailer.set_limits(self.max_image_pixels)
|
||||
|
||||
self.primary_base_path = hs.config.media_store_path # type: str
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
|
||||
self.primary_base_path: str = hs.config.media_store_path
|
||||
self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
|
||||
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer(name="media_remote")
|
||||
|
||||
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
|
||||
self.recently_accessed_locals = set() # type: Set[str]
|
||||
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
|
||||
self.recently_accessed_locals: Set[str] = set()
|
||||
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
|
@ -711,7 +711,7 @@ class MediaRepository:
|
|||
|
||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
# they have the same dimensions of a scaled one.
|
||||
thumbnails = {} # type: Dict[Tuple[int, int, str], str]
|
||||
thumbnails: Dict[Tuple[int, int, str], str] = {}
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "crop":
|
||||
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
||||
|
|
|
@ -191,7 +191,7 @@ class MediaStorage:
|
|||
|
||||
for provider in self.storage_providers:
|
||||
for path in paths:
|
||||
res = await provider.fetch(path, file_info) # type: Any
|
||||
res: Any = await provider.fetch(path, file_info)
|
||||
if res:
|
||||
logger.debug("Streaming %s from %s", path, provider)
|
||||
return res
|
||||
|
@ -233,7 +233,7 @@ class MediaStorage:
|
|||
os.makedirs(dirname)
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = await provider.fetch(path, file_info) # type: Any
|
||||
res: Any = await provider.fetch(path, file_info)
|
||||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(
|
||||
|
|
|
@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
|
||||
# memory cache mapping urls to an ObservableDeferred returning
|
||||
# JSON-encoded OG metadata
|
||||
self._cache = ExpiringCache(
|
||||
self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
|
||||
cache_name="url_previews",
|
||||
clock=self.clock,
|
||||
# don't spider URLs more often than once an hour
|
||||
expiry_ms=ONE_HOUR,
|
||||
) # type: ExpiringCache[str, ObservableDeferred]
|
||||
)
|
||||
|
||||
if self._worker_run_media_background_jobs:
|
||||
self._cleaner_loop = self.clock.looping_call(
|
||||
|
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||
|
||||
# If this URL can be accessed via oEmbed, use that instead.
|
||||
url_to_download = url # type: Optional[str]
|
||||
url_to_download: Optional[str] = url
|
||||
oembed_url = self._get_oembed_url(url)
|
||||
if oembed_url:
|
||||
# The result might be a new URL to download, or it might be HTML content.
|
||||
|
@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
|
|||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og = {} # type: Dict[str, Optional[str]]
|
||||
og: Dict[str, Optional[str]] = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if "content" in tag.attrib:
|
||||
# if we've got more than 50 tags, someone is taking the piss
|
||||
|
|
|
@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
|
|||
errcode=Codes.TOO_LARGE,
|
||||
)
|
||||
|
||||
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
|
||||
args: Dict[bytes, List[bytes]] = request.args # type: ignore
|
||||
upload_name_bytes = parse_bytes_from_args(args, "filename")
|
||||
if upload_name_bytes:
|
||||
try:
|
||||
upload_name = upload_name_bytes.decode("utf8") # type: Optional[str]
|
||||
upload_name: Optional[str] = upload_name_bytes.decode("utf8")
|
||||
except UnicodeDecodeError:
|
||||
raise SynapseError(
|
||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
|
||||
|
@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
|
|||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
try:
|
||||
content = request.content # type: IO # type: ignore
|
||||
content: IO = request.content # type: ignore
|
||||
content_uri = await self.media_repo.create_content(
|
||||
media_type, upload_name, content, content_length, requester.user
|
||||
)
|
||||
|
|
|
@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
|
|||
use_display_name = parse_boolean(request, "use_display_name", default=False)
|
||||
|
||||
try:
|
||||
emails_to_use = [
|
||||
emails_to_use: List[str] = [
|
||||
val.decode("utf-8") for val in request.args.get(b"use_email", [])
|
||||
] # type: List[str]
|
||||
]
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Query parameter use_email must be utf-8")
|
||||
except SynapseError as e:
|
||||
|
|
|
@ -907,7 +907,7 @@ class DatabasePool:
|
|||
# The sort is to ensure that we don't rely on dictionary iteration
|
||||
# order.
|
||||
keys, vals = zip(
|
||||
*[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
|
||||
*(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
|
||||
)
|
||||
|
||||
for k in keys:
|
||||
|
|
|
@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
"delete_messages_for_device", delete_messages_for_device_txn
|
||||
)
|
||||
|
||||
log_kv(
|
||||
{"message": "deleted {} messages for device".format(count), "count": count}
|
||||
)
|
||||
log_kv({"message": f"deleted {count} messages for device", "count": count})
|
||||
|
||||
# Update the cache, ensuring that we only ever increase the value
|
||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
|
|
|
@ -2010,10 +2010,6 @@ class PersistEventsStore:
|
|||
|
||||
Forward extremities are handled when we first start persisting the events.
|
||||
"""
|
||||
events_by_room: Dict[str, List[EventBase]] = {}
|
||||
for ev in events:
|
||||
events_by_room.setdefault(ev.room_id, []).append(ev)
|
||||
|
||||
query = (
|
||||
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
||||
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||
|
|
|
@ -27,8 +27,11 @@ from synapse.util import json_encoder
|
|||
_DEFAULT_CATEGORY_ID = ""
|
||||
_DEFAULT_ROLE_ID = ""
|
||||
|
||||
|
||||
# A room in a group.
|
||||
_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
|
||||
class _RoomInGroup(TypedDict):
|
||||
room_id: str
|
||||
is_public: bool
|
||||
|
||||
|
||||
class GroupServerWorkerStore(SQLBaseStore):
|
||||
|
@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||
"is_public": False # Whether this is a public room or not
|
||||
}
|
||||
"""
|
||||
|
||||
# TODO: Pagination
|
||||
|
||||
def _get_rooms_in_group_txn(txn):
|
||||
|
|
|
@ -316,6 +316,135 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
|
||||
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
|
||||
|
||||
async def count_r30v2_users(self) -> Dict[str, int]:
|
||||
"""
|
||||
Counts the number of 30 day retained users, defined as users that:
|
||||
- Appear more than once in the past 60 days
|
||||
- Have more than 30 days between the most and least recent appearances that
|
||||
occurred in the past 60 days.
|
||||
|
||||
(This is the second version of this metric, hence R30'v2')
|
||||
|
||||
Returns:
|
||||
A mapping from client type to the number of 30-day retained users for that client.
|
||||
|
||||
The dict keys are:
|
||||
- "all" (a combined number of users across any and all clients)
|
||||
- "android" (Element Android)
|
||||
- "ios" (Element iOS)
|
||||
- "electron" (Element Desktop)
|
||||
- "web" (any web application -- it's not possible to distinguish Element Web here)
|
||||
"""
|
||||
|
||||
def _count_r30v2_users(txn):
|
||||
thirty_days_in_secs = 86400 * 30
|
||||
now = int(self._clock.time())
|
||||
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
|
||||
one_day_from_now_in_secs = now + 86400
|
||||
|
||||
# This is the 'per-platform' count.
|
||||
sql = """
|
||||
SELECT
|
||||
client_type,
|
||||
count(client_type)
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
user_id,
|
||||
CASE
|
||||
WHEN
|
||||
LOWER(user_agent) LIKE '%%riot%%' OR
|
||||
LOWER(user_agent) LIKE '%%element%%'
|
||||
THEN CASE
|
||||
WHEN
|
||||
LOWER(user_agent) LIKE '%%electron%%'
|
||||
THEN 'electron'
|
||||
WHEN
|
||||
LOWER(user_agent) LIKE '%%android%%'
|
||||
THEN 'android'
|
||||
WHEN
|
||||
LOWER(user_agent) LIKE '%%ios%%'
|
||||
THEN 'ios'
|
||||
ELSE 'unknown'
|
||||
END
|
||||
WHEN
|
||||
LOWER(user_agent) LIKE '%%mozilla%%' OR
|
||||
LOWER(user_agent) LIKE '%%gecko%%'
|
||||
THEN 'web'
|
||||
ELSE 'unknown'
|
||||
END as client_type
|
||||
FROM
|
||||
user_daily_visits
|
||||
WHERE
|
||||
timestamp > ?
|
||||
AND
|
||||
timestamp < ?
|
||||
GROUP BY
|
||||
user_id,
|
||||
client_type
|
||||
HAVING
|
||||
max(timestamp) - min(timestamp) > ?
|
||||
) AS temp
|
||||
GROUP BY
|
||||
client_type
|
||||
;
|
||||
"""
|
||||
|
||||
# We initialise all the client types to zero, so we get an explicit
|
||||
# zero if they don't appear in the query results
|
||||
results = {"ios": 0, "android": 0, "web": 0, "electron": 0}
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
sixty_days_ago_in_secs * 1000,
|
||||
one_day_from_now_in_secs * 1000,
|
||||
thirty_days_in_secs * 1000,
|
||||
),
|
||||
)
|
||||
|
||||
for row in txn:
|
||||
if row[0] == "unknown":
|
||||
continue
|
||||
results[row[0]] = row[1]
|
||||
|
||||
# This is the 'all users' count.
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT
|
||||
1
|
||||
FROM
|
||||
user_daily_visits
|
||||
WHERE
|
||||
timestamp > ?
|
||||
AND
|
||||
timestamp < ?
|
||||
GROUP BY
|
||||
user_id
|
||||
HAVING
|
||||
max(timestamp) - min(timestamp) > ?
|
||||
) AS r30_users
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
sixty_days_ago_in_secs * 1000,
|
||||
one_day_from_now_in_secs * 1000,
|
||||
thirty_days_in_secs * 1000,
|
||||
),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
results["all"] = 0
|
||||
else:
|
||||
results["all"] = row[0]
|
||||
|
||||
return results
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_r30v2_users", _count_r30v2_users
|
||||
)
|
||||
|
||||
def _get_start_of_day(self):
|
||||
"""
|
||||
Returns millisecond unixtime for start of UTC day.
|
||||
|
|
|
@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
event_to_memberships = await self._get_joined_profiles_from_event_ids(
|
||||
missing_member_event_ids
|
||||
)
|
||||
users_in_room.update((row for row in event_to_memberships.values() if row))
|
||||
users_in_room.update(row for row in event_to_memberships.values() if row)
|
||||
|
||||
if event is not None and event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
|
|
|
@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
|
|||
|
||||
|
||||
def executescript(txn: Cursor, schema_path: str) -> None:
|
||||
with open(schema_path, "r") as f:
|
||||
with open(schema_path) as f:
|
||||
execute_statements_from_stream(txn, f)
|
||||
|
||||
|
||||
|
|
|
@ -577,10 +577,10 @@ class RoomStreamToken:
|
|||
entries = []
|
||||
for name, pos in self.instance_map.items():
|
||||
instance_id = await store.get_id_for_instance(name)
|
||||
entries.append("{}.{}".format(instance_id, pos))
|
||||
entries.append(f"{instance_id}.{pos}")
|
||||
|
||||
encoded_map = "~".join(entries)
|
||||
return "m{}~{}".format(self.stream, encoded_map)
|
||||
return f"m{self.stream}~{encoded_map}"
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
|
|
@ -90,8 +90,7 @@ def enumerate_leaves(node, depth):
|
|||
yield node
|
||||
else:
|
||||
for n in node.values():
|
||||
for m in enumerate_leaves(n, depth - 1):
|
||||
yield m
|
||||
yield from enumerate_leaves(n, depth - 1)
|
||||
|
||||
|
||||
P = TypeVar("P")
|
||||
|
|
|
@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d):
|
|||
"""
|
||||
if isinstance(d, TreeCacheNode):
|
||||
for value_d in d.values():
|
||||
for value in iterate_tree_cache_entry(value_d):
|
||||
yield value
|
||||
yield from iterate_tree_cache_entry(value_d)
|
||||
else:
|
||||
yield d
|
||||
|
|
|
@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
|
|||
# If pidfile already exists, we should read pid from there; to overwrite it, if
|
||||
# locking will fail, because locking attempt somehow purges the file contents.
|
||||
if os.path.isfile(pid_file):
|
||||
with open(pid_file, "r") as pid_fh:
|
||||
with open(pid_file) as pid_fh:
|
||||
old_pid = pid_fh.read()
|
||||
|
||||
# Create a lockfile so that only one instance of this daemon is running at any time.
|
||||
try:
|
||||
lock_fh = open(pid_file, "w")
|
||||
except IOError:
|
||||
except OSError:
|
||||
print("Unable to create the pidfile.")
|
||||
sys.exit(1)
|
||||
|
||||
|
@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
|
|||
# Try to get an exclusive lock on the file. This will fail if another process
|
||||
# has the file locked.
|
||||
fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
except IOError:
|
||||
except OSError:
|
||||
print("Unable to lock on the pidfile.")
|
||||
# We need to overwrite the pidfile if we got here.
|
||||
#
|
||||
|
@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
|
|||
try:
|
||||
lock_fh.write("%s" % (os.getpid()))
|
||||
lock_fh.flush()
|
||||
except IOError:
|
||||
except OSError:
|
||||
logger.error("Unable to write pid to the pidfile.")
|
||||
print("Unable to write pid to the pidfile.")
|
||||
sys.exit(1)
|
||||
|
|
|
@ -96,7 +96,7 @@ async def filter_events_for_client(
|
|||
if isinstance(ignored_users_dict, dict):
|
||||
ignore_list = frozenset(ignored_users_dict.keys())
|
||||
|
||||
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
|
||||
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
||||
|
||||
if filter_send_to_client:
|
||||
room_ids = {e.room_id for e in events}
|
||||
|
@ -353,7 +353,7 @@ async def filter_events_for_server(
|
|||
)
|
||||
|
||||
if not check_history_visibility_only:
|
||||
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
|
||||
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
||||
else:
|
||||
# We don't want to check whether users are erased, which is equivalent
|
||||
# to no users having been erased.
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import synapse
|
||||
from synapse.app.phone_stats_home import start_phone_stats_home
|
||||
from synapse.rest.client.v1 import login, room
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
FIVE_MINUTES_IN_SECONDS = 300
|
||||
ONE_DAY_IN_SECONDS = 86400
|
||||
|
||||
|
||||
|
@ -151,3 +153,243 @@ class PhoneHomeTestCase(HomeserverTestCase):
|
|||
# *Now* the user appears in R30.
|
||||
r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
|
||||
self.assertEqual(r30_results, {"all": 1, "unknown": 1})
|
||||
|
||||
|
||||
class PhoneHomeR30V2TestCase(HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def _advance_to(self, desired_time_secs: float):
|
||||
now = self.hs.get_clock().time()
|
||||
assert now < desired_time_secs
|
||||
self.reactor.advance(desired_time_secs - now)
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
|
||||
|
||||
# We don't want our tests to actually report statistics, so check
|
||||
# that it's not enabled
|
||||
assert not hs.config.report_stats
|
||||
|
||||
# This starts the needed data collection that we rely on to calculate
|
||||
# R30v2 metrics.
|
||||
start_phone_stats_home(hs)
|
||||
return hs
|
||||
|
||||
def test_r30v2_minimum_usage(self):
|
||||
"""
|
||||
Tests the minimum amount of interaction necessary for the R30v2 metric
|
||||
to consider a user 'retained'.
|
||||
"""
|
||||
|
||||
# Register a user, log it in, create a room and send a message
|
||||
user_id = self.register_user("u1", "secret!")
|
||||
access_token = self.login("u1", "secret!")
|
||||
room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
|
||||
self.helper.send(room_id, "message", tok=access_token)
|
||||
first_post_at = self.hs.get_clock().time()
|
||||
|
||||
# Give time for user_daily_visits table to be updated.
|
||||
# (user_daily_visits is updated every 5 minutes using a looping call.)
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
|
||||
# Check the R30 results do not count that user.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Advance 31 days.
|
||||
# (R30v2 includes users with **more** than 30 days between the two visits,
|
||||
# and user_daily_visits records the timestamp as the start of the day.)
|
||||
self.reactor.advance(31 * ONE_DAY_IN_SECONDS)
|
||||
# Also advance 5 minutes to let another user_daily_visits update occur
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
# (Make sure the user isn't somehow counted by this point.)
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Send a message (this counts as activity)
|
||||
self.helper.send(room_id, "message2", tok=access_token)
|
||||
|
||||
# We have to wait a few minutes for the user_daily_visits table to
|
||||
# be updated by a background process.
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
# *Now* the user is counted.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Advance to JUST under 60 days after the user's first post
|
||||
self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5)
|
||||
|
||||
# Check the user is still counted.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Advance into the next day. The user's first activity is now more than 60 days old.
|
||||
self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5)
|
||||
|
||||
# Check the user is now no longer counted in R30.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
def test_r30v2_user_must_be_retained_for_at_least_a_month(self):
|
||||
"""
|
||||
Tests that a newly-registered user must be retained for a whole month
|
||||
before appearing in the R30v2 statistic, even if they post every day
|
||||
during that time!
|
||||
"""
|
||||
|
||||
# set a custom user-agent to impersonate Element/Android.
|
||||
headers = (
|
||||
(
|
||||
"User-Agent",
|
||||
"Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
|
||||
),
|
||||
)
|
||||
|
||||
# Register a user and send a message
|
||||
user_id = self.register_user("u1", "secret!")
|
||||
access_token = self.login("u1", "secret!", custom_headers=headers)
|
||||
room_id = self.helper.create_room_as(
|
||||
room_creator=user_id, tok=access_token, custom_headers=headers
|
||||
)
|
||||
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
|
||||
|
||||
# Give time for user_daily_visits table to be updated.
|
||||
# (user_daily_visits is updated every 5 minutes using a looping call.)
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
|
||||
# Check the user does not contribute to R30 yet.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
for _ in range(30):
|
||||
# This loop posts a message every day for 30 days
|
||||
self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS)
|
||||
self.helper.send(
|
||||
room_id, "I'm still here", tok=access_token, custom_headers=headers
|
||||
)
|
||||
|
||||
# give time for user_daily_visits to update
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
# Notice that the user *still* does not contribute to R30!
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
# advance yet another day with more activity
|
||||
self.reactor.advance(ONE_DAY_IN_SECONDS)
|
||||
self.helper.send(
|
||||
room_id, "Still here!", tok=access_token, custom_headers=headers
|
||||
)
|
||||
|
||||
# give time for user_daily_visits to update
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
# *Now* the user appears in R30.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
def test_r30v2_returning_dormant_users_not_counted(self):
|
||||
"""
|
||||
Tests that dormant users (users inactive for a long time) do not
|
||||
contribute to R30v2 when they return for just a single day.
|
||||
This is a key difference between R30 and R30v2.
|
||||
"""
|
||||
|
||||
# set a custom user-agent to impersonate Element/iOS.
|
||||
headers = (
|
||||
(
|
||||
"User-Agent",
|
||||
"Riot/1.4 (iPhone; iOS 13; Scale/4.00)",
|
||||
),
|
||||
)
|
||||
|
||||
# Register a user and send a message
|
||||
user_id = self.register_user("u1", "secret!")
|
||||
access_token = self.login("u1", "secret!", custom_headers=headers)
|
||||
room_id = self.helper.create_room_as(
|
||||
room_creator=user_id, tok=access_token, custom_headers=headers
|
||||
)
|
||||
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
|
||||
|
||||
# the user goes inactive for 2 months
|
||||
self.reactor.advance(60 * ONE_DAY_IN_SECONDS)
|
||||
|
||||
# the user returns for one day, perhaps just to check out a new feature
|
||||
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
|
||||
|
||||
# Give time for user_daily_visits table to be updated.
|
||||
# (user_daily_visits is updated every 5 minutes using a looping call.)
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
|
||||
# Check that the user does not contribute to R30v2, even though it's been
|
||||
# more than 30 days since registration.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Check that this is a situation where old R30 differs:
|
||||
# old R30 DOES count this as 'retained'.
|
||||
r30_results = self.get_success(store.count_r30_users())
|
||||
self.assertEqual(r30_results, {"all": 1, "ios": 1})
|
||||
|
||||
# Now we want to check that the user will still be able to appear in
|
||||
# R30v2 as long as the user performs some other activity between
|
||||
# 30 and 60 days later.
|
||||
self.reactor.advance(32 * ONE_DAY_IN_SECONDS)
|
||||
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
|
||||
|
||||
# (give time for tables to update)
|
||||
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
|
||||
|
||||
# Check the user now satisfies the requirements to appear in R30v2.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Advance to 59.5 days after the user's first R30v2-eligible activity.
|
||||
self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS)
|
||||
|
||||
# Check the user still appears in R30v2.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
|
||||
)
|
||||
|
||||
# Advance to 60.5 days after the user's first R30v2-eligible activity.
|
||||
self.reactor.advance(ONE_DAY_IN_SECONDS)
|
||||
|
||||
# Check the user no longer appears in R30v2.
|
||||
r30_results = self.get_success(store.count_r30v2_users())
|
||||
self.assertEqual(
|
||||
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
|
||||
)
|
||||
|
|
|
@ -16,17 +16,19 @@ from typing import Dict
|
|||
from unittest.mock import Mock
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.types import Requester, StateMap
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
|
||||
from tests import unittest
|
||||
|
||||
thread_local = threading.local()
|
||||
|
||||
|
||||
class ThirdPartyRulesTestModule:
|
||||
class LegacyThirdPartyRulesTestModule:
|
||||
def __init__(self, config: Dict, module_api: ModuleApi):
|
||||
# keep a record of the "current" rules module, so that the test can patch
|
||||
# it if desired.
|
||||
|
@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
|
|||
return config
|
||||
|
||||
|
||||
def current_rules_module() -> ThirdPartyRulesTestModule:
|
||||
return thread_local.rules_module
|
||||
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
|
||||
def __init__(self, config: Dict, module_api: ModuleApi):
|
||||
super().__init__(config, module_api)
|
||||
|
||||
def on_create_room(
|
||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||
):
|
||||
return False
|
||||
|
||||
|
||||
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
|
||||
def __init__(self, config: Dict, module_api: ModuleApi):
|
||||
super().__init__(config, module_api)
|
||||
|
||||
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
|
||||
d = event.get_dict()
|
||||
content = unfreeze(event.content)
|
||||
content["foo"] = "bar"
|
||||
d["content"] = content
|
||||
return d
|
||||
|
||||
|
||||
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
||||
|
@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["third_party_event_rules"] = {
|
||||
"module": __name__ + ".ThirdPartyRulesTestModule",
|
||||
"config": {},
|
||||
}
|
||||
return config
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver()
|
||||
|
||||
load_legacy_third_party_event_rules(hs)
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
# Create a user and room to play with during the tests
|
||||
self.user_id = self.register_user("kermit", "monkey")
|
||||
self.tok = self.login("kermit", "monkey")
|
||||
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
# Some tests might prevent room creation on purpose.
|
||||
try:
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_third_party_rules(self):
|
||||
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
|
||||
|
@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
# patch the rules module with a Mock which will return False for some event
|
||||
# types
|
||||
async def check(ev, state):
|
||||
return ev.type != "foo.bar.forbidden"
|
||||
return ev.type != "foo.bar.forbidden", None
|
||||
|
||||
callback = Mock(spec=[], side_effect=check)
|
||||
current_rules_module().check_event_allowed = callback
|
||||
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
|
||||
callback
|
||||
]
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
|
@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
# first patch the event checker so that it will try to modify the event
|
||||
async def check(ev: EventBase, state):
|
||||
ev.content = {"x": "y"}
|
||||
return True
|
||||
return True, None
|
||||
|
||||
current_rules_module().check_event_allowed = check
|
||||
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
|
||||
|
||||
# now send the event
|
||||
channel = self.make_request(
|
||||
|
@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
{"x": "x"},
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.result["code"], b"500", channel.result)
|
||||
# check_event_allowed has some error handling, so it shouldn't 500 just because a
|
||||
# module did something bad.
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
event_id = channel.json_body["event_id"]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
ev = channel.json_body
|
||||
self.assertEqual(ev["content"]["x"], "x")
|
||||
|
||||
def test_modify_event(self):
|
||||
"""The module can return a modified version of the event"""
|
||||
|
@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
async def check(ev: EventBase, state):
|
||||
d = ev.get_dict()
|
||||
d["content"] = {"x": "y"}
|
||||
return d
|
||||
return True, d
|
||||
|
||||
current_rules_module().check_event_allowed = check
|
||||
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
|
||||
|
||||
# now send the event
|
||||
channel = self.make_request(
|
||||
|
@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
"msgtype": "m.text",
|
||||
"body": d["content"]["body"].upper(),
|
||||
}
|
||||
return d
|
||||
return True, d
|
||||
|
||||
current_rules_module().check_event_allowed = check
|
||||
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
|
||||
|
||||
# Send an event, then edit it.
|
||||
channel = self.make_request(
|
||||
|
@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(ev["content"]["body"], "EDITED BODY")
|
||||
|
||||
def test_send_event(self):
|
||||
"""Tests that the module can send an event into a room via the module api"""
|
||||
"""Tests that a module can send an event into a room via the module api"""
|
||||
content = {
|
||||
"msgtype": "m.text",
|
||||
"body": "Hello!",
|
||||
|
@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
"sender": self.user_id,
|
||||
}
|
||||
event: EventBase = self.get_success(
|
||||
current_rules_module().module_api.create_and_send_event_into_room(
|
||||
event_dict
|
||||
)
|
||||
self.hs.get_module_api().create_and_send_event_into_room(event_dict)
|
||||
)
|
||||
|
||||
self.assertEquals(event.sender, self.user_id)
|
||||
self.assertEquals(event.room_id, self.room_id)
|
||||
self.assertEquals(event.type, "m.room.message")
|
||||
self.assertEquals(event.content, content)
|
||||
|
||||
@unittest.override_config(
|
||||
{
|
||||
"third_party_event_rules": {
|
||||
"module": __name__ + ".LegacyChangeEvents",
|
||||
"config": {},
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_legacy_check_event_allowed(self):
|
||||
"""Tests that the wrapper for legacy check_event_allowed callbacks works
|
||||
correctly.
|
||||
"""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "Original body",
|
||||
},
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||
|
||||
event_id = channel.json_body["event_id"]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.assertIn("foo", channel.json_body["content"].keys())
|
||||
self.assertEqual(channel.json_body["content"]["foo"], "bar")
|
||||
|
||||
@unittest.override_config(
|
||||
{
|
||||
"third_party_event_rules": {
|
||||
"module": __name__ + ".LegacyDenyNewRooms",
|
||||
"config": {},
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_legacy_on_create_room(self):
|
||||
"""Tests that the wrapper for legacy on_create_room callbacks works
|
||||
correctly.
|
||||
"""
|
||||
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
|
||||
|
|
|
@ -19,7 +19,7 @@ import json
|
|||
import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Mapping, MutableMapping, Optional
|
||||
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import attr
|
||||
|
@ -53,6 +53,9 @@ class RestHelper:
|
|||
tok: str = None,
|
||||
expect_code: int = 200,
|
||||
extra_content: Optional[Dict] = None,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a room.
|
||||
|
@ -87,6 +90,7 @@ class RestHelper:
|
|||
"POST",
|
||||
path,
|
||||
json.dumps(content).encode("utf8"),
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
assert channel.result["code"] == b"%d" % expect_code, channel.result
|
||||
|
@ -175,14 +179,30 @@ class RestHelper:
|
|||
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
def send(
|
||||
self,
|
||||
room_id,
|
||||
body=None,
|
||||
txn_id=None,
|
||||
tok=None,
|
||||
expect_code=200,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
):
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
||||
content = {"msgtype": "m.text", "body": body}
|
||||
|
||||
return self.send_event(
|
||||
room_id, "m.room.message", content, txn_id, tok, expect_code
|
||||
room_id,
|
||||
"m.room.message",
|
||||
content,
|
||||
txn_id,
|
||||
tok,
|
||||
expect_code,
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
def send_event(
|
||||
|
@ -193,6 +213,9 @@ class RestHelper:
|
|||
txn_id=None,
|
||||
tok=None,
|
||||
expect_code=200,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
@ -207,6 +230,7 @@ class RestHelper:
|
|||
"PUT",
|
||||
path,
|
||||
json.dumps(content or {}).encode("utf8"),
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
assert (
|
||||
|
|
|
@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase):
|
|||
"get_state_handler",
|
||||
"get_clock",
|
||||
"get_state_resolution_handler",
|
||||
"get_account_validity_handler",
|
||||
"hostname",
|
||||
]
|
||||
)
|
||||
|
|
|
@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
|
|||
user_id = channel.json_body["user_id"]
|
||||
return user_id
|
||||
|
||||
def login(self, username, password, device_id=None):
|
||||
def login(
|
||||
self,
|
||||
username,
|
||||
password,
|
||||
device_id=None,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
):
|
||||
"""
|
||||
Log in a user, and get an access token. Requires the Login API be
|
||||
registered.
|
||||
|
@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
|
|||
body["device_id"] = device_id
|
||||
|
||||
channel = self.make_request(
|
||||
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
|
||||
"POST",
|
||||
"/_matrix/client/r0/login",
|
||||
json.dumps(body).encode("utf8"),
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
|
|
Loading…
Reference in a new issue