Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Andrew Morgan 2021-07-20 11:48:03 +01:00
commit c0121d69e7
99 changed files with 1611 additions and 683 deletions


@ -344,3 +344,15 @@ jobs:
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement
# a job which marks all the other jobs as complete, thus allowing PRs to be merged.
tests-done:
needs:
- trial
- trial-olddeps
- sytest
- portdb
- complement
runs-on: ubuntu-latest
steps:
- run: "true"


@ -0,0 +1 @@
Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.

changelog.d/10348.misc

@ -0,0 +1 @@
Run `pyupgrade` on the codebase.

changelog.d/10382.misc

@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.


@ -0,0 +1 @@
The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information.

changelog.d/10404.bugfix

@ -0,0 +1 @@
Responses from `/make_{join,leave,knock}` no longer include signatures, which would become invalid anyway once the events are modified and sent back via `/send_{join,leave,knock}`.

changelog.d/10414.bugfix

@ -0,0 +1 @@
Fix a number of logged errors caused by remote servers being down.

changelog.d/10418.misc

@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.

changelog.d/10421.misc

@ -0,0 +1 @@
Remove unused `events_by_room` code (tech debt).


@ -0,0 +1 @@
Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.

changelog.d/10430.misc

@ -0,0 +1 @@
Add a GitHub Actions job recording the success of the other jobs.

changelog.d/9884.feature

@ -0,0 +1 @@
Add a module type for the account validity feature.


@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following
API method:
```python
def ModuleApi.register_web_resource(path: str, resource: IResource)
def ModuleApi.register_web_resource(path: str, resource: IResource) -> None
```
The path is the full absolute path to register the resource at. For example, if you
@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c
and is under no obligation to implement all callbacks from the categories it registers
callbacks for.
Modules can register callbacks using one of the module API's `register_[...]_callbacks`
methods. The callback functions are passed to these methods as keyword arguments, with
the callback name as the argument name and the function as its value. This is demonstrated
in the example below. A `register_[...]_callbacks` method exists for each module type
documented in this section.
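For example, a module might register one of the spam checker callbacks documented
below from its constructor. This is a minimal sketch: the `ExampleModule` class and
its callback body are invented for illustration, and only
`register_spam_checker_callbacks` and the `user_may_invite` callback name come from
the documentation itself.
```python
class ExampleModule:
    def __init__(self, config: dict, api):
        self._api = api
        # Callbacks are passed as keyword arguments; callbacks the module
        # doesn't implement are simply omitted.
        api.register_spam_checker_callbacks(
            user_may_invite=self.user_may_invite,
        )

    async def user_may_invite(
        self, inviter: str, invitee: str, room_id: str
    ) -> bool:
        # Allow every invite in this example.
        return True
```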
#### Spam checker callbacks
To register one of the callbacks described in this section, a module needs to use the
module API's `register_spam_checker_callbacks` method. The callback functions are passed
to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the
argument name and the function as its value. This is demonstrated in the example below.
Spam checker callbacks allow module developers to implement spam mitigation actions for
Synapse instances. Spam checker callbacks can be registered using the module API's
`register_spam_checker_callbacks` method.
The available spam checker callbacks are:
@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
Called when processing an invitation. The module must return a `bool` indicating whether
the inviter can invite the invitee to the given room. Both inviter and invitee are
represented by their Matrix user ID (i.e. `@alice:example.com`).
represented by their Matrix user ID (e.g. `@alice:example.com`).
```python
async def user_may_create_room(user: str) -> bool
@ -181,13 +186,103 @@ The arguments passed to this callback are:
```python
async def check_media_file_for_spam(
file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper",
file_info: "synapse.rest.media.v1._base.FileInfo"
file_info: "synapse.rest.media.v1._base.FileInfo",
) -> bool
```
Called when storing a local or remote file. The module must return a boolean indicating
whether the given file can be stored in the homeserver's media store.
#### Account validity callbacks
Account validity callbacks allow module developers to add extra steps to verify the
validity of an account, i.e. to check whether a user can be granted access to their account on the
Synapse instance. Account validity callbacks can be registered using the module API's
`register_account_validity_callbacks` method.
The available account validity callbacks are:
```python
async def is_user_expired(user: str) -> Optional[bool]
```
Called when processing any authenticated request (except for logout requests). The module
can return a `bool` to indicate whether the user has expired and should be locked out of
their account, or `None` if the module wasn't able to figure it out. The user is
represented by their Matrix user ID (e.g. `@alice:example.com`).
If the module returns `True`, the current request will be denied with the error code
`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't
invalidate the user's access token.
```python
async def on_user_registration(user: str) -> None
```
Called after successfully registering a user, in case the module needs to perform extra
operations to keep track of them (e.g. add them to a database table). The user is
represented by their Matrix user ID.
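For example, a module implementing both callbacks might look like the following
minimal sketch, in which the in-memory `_expired_users` set is a stand-in for
whatever storage a real module would use:
```python
class ExampleAccountValidityModule:
    def __init__(self, config: dict, api):
        self._expired_users = set()  # stand-in for real storage
        api.register_account_validity_callbacks(
            is_user_expired=self.is_user_expired,
            on_user_registration=self.on_user_registration,
        )

    async def is_user_expired(self, user: str):
        # Return a bool if this module knows the answer, or None so that
        # other modules (or the legacy implementation) can decide instead.
        return True if user in self._expired_users else None

    async def on_user_registration(self, user: str) -> None:
        # Keep track of the newly registered user, e.g. by seeding their
        # validity record.
        self._expired_users.discard(user)
```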
#### Third party rules callbacks
Third party rules callbacks allow module developers to add extra checks to verify the
validity of incoming events. Third party event rules callbacks can be registered using
the module API's `register_third_party_rules_callbacks` method.
The available third party rules callbacks are:
```python
async def check_event_allowed(
event: "synapse.events.EventBase",
state_events: "synapse.types.StateMap",
) -> Tuple[bool, Optional[dict]]
```
**<span style="color:red">
This callback is very experimental and can and will break without notice. Module developers
are encouraged to implement `check_event_for_spam` from the spam checker category instead.
</span>**
Called when processing any incoming event, with the event and a `StateMap`
representing the current state of the room the event is being sent into. A `StateMap` is
a dictionary that maps tuples containing an event type and a state key to the
corresponding state event. For example, retrieving the room's `m.room.create` event from
the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
The module must return a boolean indicating whether the event can be allowed.
Note that this callback function processes incoming events coming via federation
traffic (on top of client traffic). This means denying an event might cause the local
copy of the room's history to diverge from that of remote servers. This may cause
federation issues in the room. It is strongly recommended to only deny events using this
callback function if the sender is a local user, or in a private federation in which all
servers are using the same module, with the same configuration.
If the boolean returned by the module is `True`, it may also tell Synapse to replace the
event with new data by returning the new event's data as a dictionary. In order to do
that, it is recommended that the module call `event.get_dict()` to get the current event as a
dictionary, and then modify the returned dictionary accordingly.
Note that replacing the event only works for events sent by local users, not for events
received over federation.
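As an illustration, a callback that allows every event but annotates topic changes
could look like this sketch (the `example.topic_checked` content key is invented
for the example):
```python
async def check_event_allowed(event, state_events):
    if event.type == "m.room.topic" and event.is_state():
        # Build a replacement dict from the current event and amend it.
        new_event = event.get_dict()
        new_event["content"]["example.topic_checked"] = True
        return True, new_event
    # Otherwise allow the event unchanged.
    return True, None
```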
```python
async def on_create_room(
requester: "synapse.types.Requester",
request_content: dict,
is_requester_admin: bool,
) -> None
```
Called when processing a room creation request, with the `Requester` object for the user
performing the request, a dictionary representing the room creation request's JSON body
(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom)
for a list of possible parameters), and a boolean indicating whether the user performing
the request is a server admin.
Modules can modify the `request_content` (e.g. by adding events to its `initial_state`),
or deny the room's creation by raising a `module_api.errors.SynapseError`.
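As a sketch (the public-room restriction is an invented policy, purely for
illustration), such a callback could adjust or deny a room creation request like so:
```python
from synapse.module_api import errors

async def on_create_room(requester, request_content, is_requester_admin):
    # Deny public room creation by non-admins.
    if not is_requester_admin and request_content.get("visibility") == "public":
        raise errors.SynapseError(403, "Only admins may create public rooms")
    # Otherwise tweak the request, e.g. ensure initial_state is present so
    # later code can append to it.
    request_content.setdefault("initial_state", [])
```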
### Porting an existing module that uses the old interface
In order to port a module that uses Synapse's old module interface, its author needs to:


@ -1310,91 +1310,6 @@ account_threepid_delegates:
#auto_join_rooms_for_guests: false
## Account Validity ##
# Optional account validity configuration. This allows for accounts to be denied
# any request after a given period.
#
# Once this feature is enabled, Synapse will look for registered users without an
# expiration date at startup and will add one to every account it found using the
# current settings at that time.
# This means that, if a validity period is set, and Synapse is restarted (it will
# then derive an expiration date from the current validity period), and some time
# after that the validity period changes and Synapse is restarted, the users'
# expiration dates won't be updated unless their account is manually renewed. This
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10% of the validity period.
#
account_validity:
# The account validity feature is disabled by default. Uncomment the
# following line to enable it.
#
#enabled: true
# The period after which an account is valid after its registration. When
# renewing the account, its validity period will be extended by this amount
# of time. This parameter is required when using the account validity
# feature.
#
#period: 6w
# The amount of time before an account's expiry date at which Synapse will
# send an email to the account's email address with a renewal link. By
# default, no such emails are sent.
#
# If you enable this setting, you will also need to fill out the 'email' and
# 'public_baseurl' configuration sections.
#
#renew_at: 1w
# The subject of the email sent out with the renewal link. '%(app)s' can be
# used as a placeholder for the 'app_name' parameter from the 'email'
# section.
#
# Note that the placeholder must be written '%(app)s', including the
# trailing 's'.
#
# If this is not set, a default value is used.
#
#renew_email_subject: "Renew your %(app)s account"
# Directory in which Synapse will try to find templates for the HTML files to
# serve to the user when trying to renew an account. If not set, default
# templates from within the Synapse package will be used.
#
# The currently available templates are:
#
# * account_renewed.html: Displayed to the user after they have successfully
# renewed their account.
#
# * account_previously_renewed.html: Displayed to the user if they attempt to
# renew their account with a token that is valid, but that has already
# been used. In this case the account is not renewed again.
#
# * invalid_token.html: Displayed to the user when they try to renew an account
# with an unknown or invalid renewal token.
#
# See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
# default template contents.
#
# The file name of some of these templates can be configured below for legacy
# reasons.
#
#template_dir: "res/templates"
# A custom file name for the 'account_renewed.html' template.
#
# If not set, the file is assumed to be named "account_renewed.html".
#
#account_renewed_html_path: "account_renewed.html"
# A custom file name for the 'invalid_token.html' template.
#
# If not set, the file is assumed to be named "invalid_token.html".
#
#invalid_token_html_path: "invalid_token.html"
## Metrics ##
# Enable collection and rendering of performance metrics
@ -2739,19 +2654,6 @@ stats:
# action: allow
# Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to
# override the methods defined in synapse/events/third_party_rules.py.
#
# This feature is designed to be used in closed federations only, where each
# participating server enforces the same rules.
#
#third_party_event_rules:
# module: "my_custom_project.SuperRulesSet"
# config:
# example_option: 'things'
## Opentracing ##
# These settings enable opentracing, which implements distributed tracing.


@ -86,6 +86,19 @@ process, for example:
```
# Upgrading to v1.39.0
## Deprecation of the current third-party rules module interface
The current third-party rules module interface is deprecated in favour of the new generic
modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer
to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface)
to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules)
to update their configuration once the modules they are using have been updated.
We plan to remove support for the current third-party rules interface in September 2021.
# Upgrading to v1.38.0
## Re-indexing of `events` table on Postgres databases


@ -62,6 +62,7 @@ class Auth:
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
@ -69,9 +70,6 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
)
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@ -187,12 +185,17 @@ class Auth:
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
if self._account_validity_enabled and not allow_expired:
if await self.store.is_account_expired(
user_info.user_id, self.clock.time_msec()
if not allow_expired:
if await self._account_validity_handler.is_user_expired(
user_info.user_id
):
# Raise the error if either an account validity module has determined
# the account has expired, or the legacy account validity
# implementation is enabled and determined the account has expired
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
403,
"User account has expired",
errcode=Codes.EXPIRED_ACCOUNT,
)
device_id = user_info.device_id


@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@ -368,6 +369,7 @@ async def start(hs: "HomeServer"):
module(config=config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)


@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer):
elif listener.type == "metrics":
if not self.config.enable_metrics:
logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
"Metrics listener configured, but "
"enable_metrics is not True!"
)
else:
_base.listen_metrics(listener.bind_addresses, listener.port)


@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer):
elif listener.type == "metrics":
if not self.config.enable_metrics:
logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
"Metrics listener configured, but "
"enable_metrics is not True!"
)
else:
_base.listen_metrics(listener.bind_addresses, listener.port)


@ -71,6 +71,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
# General statistics
#
store = hs.get_datastore()
stats["homeserver"] = hs.config.server_name
stats["server_context"] = hs.config.server_context
stats["timestamp"] = now
@ -79,34 +81,38 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
stats["total_users"] = await hs.get_datastore().count_all_users()
stats["total_users"] = await store.count_all_users()
total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
total_nonbridged_users = await store.count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = await hs.get_datastore().count_daily_user_type()
daily_user_type_results = await store.count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
room_count = await hs.get_datastore().get_room_count()
room_count = await store.get_room_count()
stats["total_room_count"] = room_count
stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
stats["daily_active_users"] = await store.count_daily_users()
stats["monthly_active_users"] = await store.count_monthly_users()
daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages()
daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages()
stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
stats["daily_active_rooms"] = await store.count_daily_active_rooms()
stats["daily_messages"] = await store.count_daily_messages()
daily_sent_messages = await store.count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
r30_results = await hs.get_datastore().count_r30_users()
r30_results = await store.count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
r30v2_results = await store.count_r30v2_users()
for name, count in r30v2_results.items():
stats["r30v2_users_" + name] = count
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@ -115,8 +121,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
#
# This only reports info about the *main* database.
stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
stats["database_engine"] = store.db_pool.engine.module.__name__
stats["database_server_version"] = store.db_pool.engine.server_version
#
# Logging configuration


@ -18,6 +18,21 @@ class AccountValidityConfig(Config):
section = "account_validity"
def read_config(self, config, **kwargs):
"""Parses the old account validity config. The config format looks like this:
account_validity:
enabled: true
period: 6w
renew_at: 1w
renew_email_subject: "Renew your %(app)s account"
template_dir: "res/templates"
account_renewed_html_path: "account_renewed.html"
invalid_token_html_path: "invalid_token.html"
We expect admins to use modules for this feature (which is why it doesn't appear
in the sample config file), but we want to keep support for it around for a bit
for backwards compatibility.
"""
account_validity_config = config.get("account_validity") or {}
self.account_validity_enabled = account_validity_config.get("enabled", False)
self.account_validity_renew_by_email_enabled = (
@ -75,90 +90,3 @@ class AccountValidityConfig(Config):
],
account_validity_template_dir,
)
def generate_config_section(self, **kwargs):
return """\
## Account Validity ##
# Optional account validity configuration. This allows for accounts to be denied
# any request after a given period.
#
# Once this feature is enabled, Synapse will look for registered users without an
# expiration date at startup and will add one to every account it found using the
# current settings at that time.
# This means that, if a validity period is set, and Synapse is restarted (it will
# then derive an expiration date from the current validity period), and some time
# after that the validity period changes and Synapse is restarted, the users'
# expiration dates won't be updated unless their account is manually renewed. This
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10% of the validity period.
#
account_validity:
# The account validity feature is disabled by default. Uncomment the
# following line to enable it.
#
#enabled: true
# The period after which an account is valid after its registration. When
# renewing the account, its validity period will be extended by this amount
# of time. This parameter is required when using the account validity
# feature.
#
#period: 6w
# The amount of time before an account's expiry date at which Synapse will
# send an email to the account's email address with a renewal link. By
# default, no such emails are sent.
#
# If you enable this setting, you will also need to fill out the 'email' and
# 'public_baseurl' configuration sections.
#
#renew_at: 1w
# The subject of the email sent out with the renewal link. '%(app)s' can be
# used as a placeholder for the 'app_name' parameter from the 'email'
# section.
#
# Note that the placeholder must be written '%(app)s', including the
# trailing 's'.
#
# If this is not set, a default value is used.
#
#renew_email_subject: "Renew your %(app)s account"
# Directory in which Synapse will try to find templates for the HTML files to
# serve to the user when trying to renew an account. If not set, default
# templates from within the Synapse package will be used.
#
# The currently available templates are:
#
# * account_renewed.html: Displayed to the user after they have successfully
# renewed their account.
#
# * account_previously_renewed.html: Displayed to the user if they attempt to
# renew their account with a token that is valid, but that has already
# been used. In this case the account is not renewed again.
#
# * invalid_token.html: Displayed to the user when they try to renew an account
# with an unknown or invalid renewal token.
#
# See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
# default template contents.
#
# The file name of some of these templates can be configured below for legacy
# reasons.
#
#template_dir: "res/templates"
# A custom file name for the 'account_renewed.html' template.
#
# If not set, the file is assumed to be named "account_renewed.html".
#
#account_renewed_html_path: "account_renewed.html"
# A custom file name for the 'invalid_token.html' template.
#
# If not set, the file is assumed to be named "invalid_token.html".
#
#invalid_token_html_path: "invalid_token.html"
"""


@ -64,7 +64,7 @@ def load_appservices(hostname, config_files):
for config_file in config_files:
try:
with open(config_file, "r") as f:
with open(config_file) as f:
appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
if appservice.id in seen_ids:
raise ConfigError(


@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config):
self.third_party_event_rules = load_module(
provider, ("third_party_event_rules",)
)
def generate_config_section(self, **kwargs):
return """\
# Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to
# override the methods defined in synapse/events/third_party_rules.py.
#
# This feature is designed to be used in closed federations only, where each
# participating server enforces the same rules.
#
#third_party_event_rules:
# module: "my_custom_project.SuperRulesSet"
# config:
# example_option: 'things'
"""


@ -66,10 +66,8 @@ class TlsConfig(Config):
if self.federation_client_minimum_tls_version == "1.3":
if getattr(SSL, "OP_NO_TLSv1_3", None) is None:
raise ConfigError(
(
"federation_client_minimum_tls_version cannot be 1.3, "
"your OpenSSL does not support it"
)
"federation_client_minimum_tls_version cannot be 1.3, "
"your OpenSSL does not support it"
)
# Whitelist of domains to not verify certificates for


@ -291,6 +291,20 @@ class EventBase(metaclass=abc.ABCMeta):
return pdu_json
def get_templated_pdu_json(self) -> JsonDict:
"""
Return a JSON object suitable for a templated event, as used in the
make_{join,leave,knock} workflow.
"""
# By using _dict directly we don't pull in signatures/unsigned.
template_json = dict(self._dict)
# The hashes (similar to the signature) need to be recalculated by the
# joining/leaving/knocking server after (potentially) modifying the
# event.
template_json.pop("hashes")
return template_json
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))


@ -11,16 +11,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Union
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool]
]
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
[str, StateMap[EventBase], str], Awaitable[bool]
]
def load_legacy_third_party_event_rules(hs: "HomeServer"):
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.third_party_event_rules is None:
return
module, config = hs.config.third_party_event_rules
api = hs.get_module_api()
third_party_rules = module(config=config, module_api=api)
# The known hooks. If a module implements a method whose name appears in this set,
# we'll want to register it.
third_party_event_rules_methods = {
"check_event_allowed",
"on_create_room",
"check_threepid_can_be_invited",
"check_visibility_can_be_modified",
}
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# We return a separate wrapper for these methods because, in order to wrap them
# correctly, we need to await their results. Therefore it doesn't make a lot of
# sense to make them go through the run() wrapper.
if f.__name__ == "check_event_allowed":
# We need to wrap check_event_allowed because its old form would return either
# a boolean or a dict, but now we want to return the dict separately from the
# boolean.
async def wrap_check_event_allowed(
event: EventBase,
state_events: StateMap[EventBase],
) -> Tuple[bool, Optional[dict]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
res = await f(event, state_events)
if isinstance(res, dict):
return True, res
else:
return res, None
return wrap_check_event_allowed
if f.__name__ == "on_create_room":
# We need to wrap on_create_room because its old form would return a boolean
# if the room creation is denied, but now we just want it to raise an
# exception.
async def wrap_on_create_room(
requester: Requester, config: dict, is_requester_admin: bool
) -> None:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
res = await f(requester, config, is_requester_admin)
if res is False:
raise SynapseError(
403,
"Room creation forbidden with these parameters",
)
return wrap_on_create_room
def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(third_party_rules, hook, None))
for hook in third_party_event_rules_methods
}
api.register_third_party_rules_callbacks(**hooks)
class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra
@ -35,36 +143,65 @@ class ThirdPartyEventRules:
self.store = hs.get_datastore()
module = None
config = None
if hs.config.third_party_event_rules:
module, config = hs.config.third_party_event_rules
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
self._check_threepid_can_be_invited_callbacks: List[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = []
self._check_visibility_can_be_modified_callbacks: List[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = []
if module is not None:
self.third_party_rules = module(
config=config,
module_api=hs.get_module_api(),
def register_third_party_rules_callbacks(
self,
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = None,
check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None,
):
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed)
if on_create_room is not None:
self._on_create_room_callbacks.append(on_create_room)
if check_threepid_can_be_invited is not None:
self._check_threepid_can_be_invited_callbacks.append(
check_threepid_can_be_invited,
)
if check_visibility_can_be_modified is not None:
self._check_visibility_can_be_modified_callbacks.append(
check_visibility_can_be_modified,
)
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> Union[bool, dict]:
) -> Tuple[bool, Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
The module can return:
* True: the event is allowed.
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
* a dict: replacement event data.
If the event is allowed, the module can also return a dictionary to use as a
replacement for the event.
Args:
event: The event to be checked.
context: The context of the event.
Returns:
The result from the ThirdPartyRules module, as above
The result from the ThirdPartyRules module, as above.
"""
if self.third_party_rules is None:
return True
# Bail out early without hitting the store if we don't have any callbacks to run.
if len(self._check_event_allowed_callbacks) == 0:
return True, None
prev_state_ids = await context.get_prev_state_ids()
@ -77,29 +214,46 @@ class ThirdPartyEventRules:
# the hashes and signatures.
event.freeze()
return await self.third_party_rules.check_event_allowed(event, state_events)
for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await callback(event, state_events)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
# Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event.
if res is False:
return res, None
elif isinstance(replacement_data, dict):
return True, replacement_data
return True, None
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool:
"""Intercept requests to create room to allow, deny or update the
request config.
) -> None:
"""Intercept requests to create room to maybe deny it (via an exception) or
update the request config.
Args:
requester
config: The creation config from the client.
is_requester_admin: If the requester is an admin
Returns:
Whether room creation is allowed or denied.
"""
for callback in self._on_create_room_callbacks:
try:
await callback(requester, config, is_requester_admin)
except Exception as e:
# Don't silence the errors raised by this callback since we expect it to
# raise an exception to deny the creation of the room; instead make sure
# it's a SynapseError we can send to clients.
if not isinstance(e, SynapseError):
e = SynapseError(
403, "Room creation forbidden with these parameters"
)
if self.third_party_rules is None:
return True
return await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
raise e
async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
@ -114,15 +268,20 @@ class ThirdPartyEventRules:
Returns:
True if the 3PID can be invited, False if not.
"""
if self.third_party_rules is None:
# Bail out early without hitting the store if we don't have any callbacks to run.
if len(self._check_threepid_can_be_invited_callbacks) == 0:
return True
state_events = await self._get_state_map_for_room(room_id)
return await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
for callback in self._check_threepid_can_be_invited_callbacks:
try:
if await callback(medium, address, state_events) is False:
return False
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
return True
async def check_visibility_can_be_modified(
self, room_id: str, new_visibility: str
@ -137,18 +296,20 @@ class ThirdPartyEventRules:
Returns:
True if the room's visibility can be modified, False if not.
"""
if self.third_party_rules is None:
return True
check_func = getattr(
self.third_party_rules, "check_visibility_can_be_modified", None
)
if not check_func or not callable(check_func):
# Bail out early without hitting the store if we don't have any callback
if len(self._check_visibility_can_be_modified_callbacks) == 0:
return True
state_events = await self._get_state_map_for_room(room_id)
return await check_func(room_id, state_events, new_visibility)
for callback in self._check_visibility_can_be_modified_callbacks:
try:
if await callback(room_id, state_events, new_visibility) is False:
return False
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
return True
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
"""Given a room ID, return the state events of that room.


@ -562,8 +562,7 @@ class FederationServer(FederationBase):
raise IncompatibleRoomVersionError(room_version=room_version)
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
async def on_invite_request(
self, origin: str, content: JsonDict, room_version_id: str
@ -611,8 +610,7 @@ class FederationServer(FederationBase):
room_version = await self.store.get_room_version_id(room_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
async def on_send_leave_request(
self, origin: str, content: JsonDict, room_id: str
@ -659,9 +657,8 @@ class FederationServer(FederationBase):
)
pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {
"event": pdu.get_pdu_json(time_now),
"event": pdu.get_templated_pdu_json(),
"room_version": room_version.identifier,
}


@ -38,10 +38,10 @@ class BaseHandler:
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
@ -55,12 +55,12 @@ class BaseHandler:
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
) # type: Optional[Ratelimiter]
)
else:
self.admin_redaction_ratelimiter = None


@ -15,9 +15,11 @@
import email.mime.multipart
import email.utils
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from synapse.api.errors import StoreError, SynapseError
from twisted.web.http import Request
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
@ -27,6 +29,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
class AccountValidityHandler:
def __init__(self, hs: "HomeServer"):
@ -70,6 +81,99 @@ class AccountValidityHandler:
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
self._on_legacy_send_mail_callback: Optional[
ON_LEGACY_SEND_MAIL_CALLBACK
] = None
self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
# The legacy admin requests callback isn't a protected attribute because we need
# to access it from the admin servlet, which is outside of this handler.
self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
def register_account_validity_callbacks(
self,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
):
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
if on_user_registration is not None:
self._on_user_registration_callbacks.append(on_user_registration)
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
# an admin one). As part of moving the feature into a module, we need to change
# the path from /_matrix/client/unstable/account_validity/... to
# /_synapse/client/account_validity, because:
#
# * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
# * the way we register servlets means that modules can't register resources
# under /_matrix/client
#
# We need to allow for a transition period between the old and new endpoints
# in order to allow for clients to update (and for emails to be processed).
#
# Once the email-account-validity module is loaded, it will take control of account
# validity by moving the rows from our `account_validity` table into its own table.
#
# Therefore, we need to allow modules (in practice just the one implementing the
# email-based account validity) to temporarily hook into the legacy endpoints so we
# can route the traffic coming into the old endpoints into the module, which is
# why we have the following three temporary hooks.
if on_legacy_send_mail is not None:
if self._on_legacy_send_mail_callback is not None:
raise RuntimeError("Tried to register on_legacy_send_mail twice")
self._on_legacy_send_mail_callback = on_legacy_send_mail
if on_legacy_renew is not None:
if self._on_legacy_renew_callback is not None:
raise RuntimeError("Tried to register on_legacy_renew twice")
self._on_legacy_renew_callback = on_legacy_renew
if on_legacy_admin_request is not None:
if self.on_legacy_admin_request_callback is not None:
raise RuntimeError("Tried to register on_legacy_admin_request twice")
self.on_legacy_admin_request_callback = on_legacy_admin_request
async def is_user_expired(self, user_id: str) -> bool:
"""Checks if a user has expired against third-party modules.
Args:
user_id: The user to check the expiry of.
Returns:
Whether the user has expired.
"""
for callback in self._is_user_expired_callbacks:
expired = await callback(user_id)
if expired is not None:
return expired
if self._account_validity_enabled:
# If no module could determine whether the user has expired and the legacy
# configuration is enabled, fall back to it.
return await self.store.is_account_expired(user_id, self.clock.time_msec())
return False
async def on_user_registration(self, user_id: str):
"""Tell third-party modules about a user's registration.
Args:
user_id: The ID of the newly registered user.
"""
for callback in self._on_user_registration_callbacks:
await callback(user_id)
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
@ -95,6 +199,17 @@ class AccountValidityHandler:
Raises:
SynapseError if the user is not set to renew.
"""
# If a module supports sending a renewal email from here, do that, otherwise do
# the legacy dance.
if self._on_legacy_send_mail_callback is not None:
await self._on_legacy_send_mail_callback(user_id)
return
if not self._account_validity_renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
# If this user isn't set to be expired, raise an error.
@ -209,6 +324,10 @@ class AccountValidityHandler:
token is considered stale. A token is stale if the 'token_used_ts_ms' db column
is non-null.
This method exists to support handling the legacy account validity /renew
endpoint. If a module implements the on_legacy_renew callback, then this process
is delegated to the module instead.
Args:
renewal_token: Token sent with the renewal request.
Returns:
@ -218,6 +337,11 @@ class AccountValidityHandler:
* An int representing the user's expiry timestamp as milliseconds since the
epoch, or 0 if the token was invalid.
"""
# If a module supports triggering a renew from here, do that, otherwise do the
# legacy dance.
if self._on_legacy_renew_callback is not None:
return await self._on_legacy_renew_callback(renewal_token)
try:
(
user_id,


@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
to_key = RoomStreamToken(None, stream_ordering)
# Events that we've processed in this room
written_events = set() # type: Set[str]
written_events: Set[str] = set()
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
unseen_to_child_events = {} # type: Dict[str, Set[str]]
unseen_to_child_events: Dict[str, Set[str]] = {}
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most


@ -96,7 +96,7 @@ class ApplicationServicesHandler:
self.current_max, limit
)
events_by_room = {} # type: Dict[str, List[EventBase]]
events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@ -275,7 +275,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events = [] # type: List[JsonDict]
events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
@ -375,7 +375,7 @@ class ApplicationServicesHandler:
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
protocols: Dict[str, List[JsonDict]] = {}
# Collect up all the individual protocol responses out of the ASes
for s in services:


@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
# A mapping of user ID to extra attributes to include in the login
# response.
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
async def validate_user_via_ui_auth(
self,
@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
sid = None # type: Optional[str]
sid: Optional[str] = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
)
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
errordict: Dict[str, Any] = {}
if "type" in authdict:
login_type = authdict["type"] # type: str
login_type: str = authdict["type"]
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
params = {} # type: Dict[str, Any]
params: Dict[str, Any] = {}
for f in public_flows:
for stage in f:
@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
user_id_to_verify = await self.get_session_data(
user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
)
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify


@ -40,7 +40,7 @@ class CasError(Exception):
def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return f"{self.error}: {self.error_description}"
return self.error
@ -171,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {} # type: Dict[str, List[Optional[str]]]
attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text


@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id
)
hosts = set() # type: Set[str]
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
@ -613,20 +613,20 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
self._pending_updates = (
{}
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
self._pending_updates: Dict[
str, List[Tuple[str, str, Iterable[str], JsonDict]]
] = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
) # type: ExpiringCache[str, Set[str]]
)
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@ -755,7 +755,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
seen_updates: Set[str] = self._seen_updates.get(user_id, set())
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)


@ -203,7 +203,7 @@ class DeviceMessageHandler:
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (


@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler):
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
result = await self.get_association_from_room_alias(
room_alias
) # type: Optional[RoomAliasMapping]
result: Optional[
RoomAliasMapping
] = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id


@ -115,9 +115,9 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
)
# separate users by domain.
# make a map from domain to user_id to device_ids
@ -136,7 +136,7 @@ class E2eKeysHandler:
# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
failures: Dict[str, JsonDict] = {}
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@ -151,11 +151,9 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
@ -362,9 +360,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = [] # type: List[Tuple[str, Optional[str]]]
local_query: List[Tuple[str, Optional[str]]] = []
result_dict = {} # type: Dict[str, Dict[str, dict]]
result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@ -402,9 +400,9 @@ class E2eKeysHandler:
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
"""Handle a device key query from a federated server"""
device_keys_query = query_body.get(
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
)
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@ -421,8 +419,8 @@ class E2eKeysHandler:
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]]
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
@ -439,8 +437,8 @@ class E2eKeysHandler:
results = await self.store.claim_e2e_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key.
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
failures = {} # type: Dict[str, JsonDict]
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
@ -768,8 +766,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
@ -930,8 +928,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
device_ids = [] # type: List[str]
device_ids: List[str] = []
logger.info("pending updates: %r", pending_updates)


@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = [] # type: List[JsonDict]
to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
continue
@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = await self.store.get_users_in_room(
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
) # type: Iterable[str]
)
else:
users = [event.state_key]


@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._room_backfill = Linearizer("room_backfill")
@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: List[StateMap[str]]
state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
@ -735,7 +735,7 @@ class FederationHandler(BaseHandler):
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(await self.store.get_events(missing_desired_events, allow_rejected=True))
await self.store.get_events(missing_desired_events, allow_rejected=True)
)
# check for events which were in the wrong room.
@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
# exact key to expect. Otherwise check it matches any key we
# have for that device.
current_keys = [] # type: Container[str]
current_keys: Container[str] = []
if device:
keys = device.get("keys", {}).get("keys", {})
@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains = {} # type: Dict[str, int]
joined_domains: Dict[str, int] = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version(room_id)
event_map = {} # type: Dict[str, EventBase]
event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
with nested_logging_context(event_id):
@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler):
# Ask the remote server to create a valid knock event for us. Once received,
# we sign the event.
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
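`params` here carries every room version the local server supports; these end up as repeated `ver` query parameters on the federation request asking the remote server to draft the knock event. An illustrative sketch (the stable v1 `make_knock` path is assumed, and all identifiers are made up):

```python
from typing import Dict, Iterable

params: Dict[str, Iterable[str]] = {"ver": ["6", "7"]}
# Roughly what goes on the wire (before percent-encoding):
# GET /_matrix/federation/v1/make_knock/!room:remote.example/@knocker:local.example?ver=6&ver=7
```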
@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler):
# for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK:
event_allowed = await self.third_party_event_rules.check_event_allowed(
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler):
state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
} # type: StateMap[str]
}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
last_exception = None # type: Optional[Exception]
last_exception: Optional[Exception] = None
# for each public key in the 3pid invite event
for public_key_object in event_auth.get_public_keys(invite_event):


@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]]
destinations: Dict[str, Set[str]] = {}
local_users = set()
for user_id in user_ids:
@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
failed_results = [] # type: List[str]
failed_results: List[str] = []
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(


@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler):
)
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
content = {
"mxid": mxid,
@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler):
return data["mxid"]
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except IOError as e:
except OSError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None


@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self.snapshot_cache: ResponseCache[
Tuple[
str,
Optional[StreamToken],
Optional[StreamToken],
str,
Optional[int],
bool,
bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state


@ -81,7 +81,7 @@ class MessageHandler:
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
self._scheduled_expiry = None # type: Optional[IDelayedCall]
self._scheduled_expiry: Optional[IDelayedCall] = None
if not hs.config.worker_app:
run_as_background_process(
@ -196,9 +196,7 @@ class MessageHandler:
room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state_events[
event.event_id
] # type: Mapping[Any, EventBase]
room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
else:
raise AuthError(
403,
@ -421,9 +419,9 @@ class EventCreationHandler:
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = (
self.third_party_event_rules: "ThirdPartyEventRules" = (
self.hs.get_third_party_event_rules()
) # type: ThirdPartyEventRules
)
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@ -440,7 +438,7 @@ class EventCreationHandler:
#
# map from room id to time-of-last-attempt.
#
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent.
self._dummy_events_threshold = hs.config.dummy_events_threshold
@ -465,9 +463,7 @@ class EventCreationHandler:
# Stores the state groups we've recently added to the joined hosts
# external cache. Note that the timeout must be significantly less than
# the TTL on the external cache.
self._external_cache_joined_hosts_updates = (
None
) # type: Optional[ExpiringCache]
self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
if self._external_cache.is_enabled():
self._external_cache_joined_hosts_updates = ExpiringCache(
"_external_cache_joined_hosts_updates",
@ -953,10 +949,10 @@ class EventCreationHandler:
if requester:
context.app_service = requester.app_service
third_party_result = await self.third_party_event_rules.check_event_allowed(
res, new_content = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not third_party_result:
if res is False:
logger.info(
"Event %s forbidden by third-party rules",
event,
@ -964,11 +960,11 @@ class EventCreationHandler:
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
elif isinstance(third_party_result, dict):
elif new_content is not None:
# the third-party rules want to replace the event. We'll need to build a new
# event.
event, context = await self._rebuild_event_after_third_party_rules(
third_party_result, event
new_content, event
)
self.validator.validate_new(event, self.config)
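Both call sites above unpack a two-tuple because `check_event_allowed` now yields a boolean verdict plus an optional replacement dictionary that Synapse uses to rebuild the event. A sketch of a module-side callback honouring that contract (the plain-`dict` signature is a simplification for illustration; real callbacks receive Synapse's event and state types):

```python
from typing import Optional, Tuple


async def check_event_allowed(event: dict, state: dict) -> Tuple[bool, Optional[dict]]:
    """Forbid, allow, or ask Synapse to rebuild an event."""
    if "forbidden-word" in str(event.get("content", {})):
        return False, None  # forbid the event outright
    if event.get("type") == "m.room.topic":
        # Allow the event, but hand back a replacement for Synapse to rebuild.
        return True, {**event, "content": {"topic": "a moderated topic"}}
    return True, None  # allow the event unchanged
```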
@ -1299,7 +1295,7 @@ class EventCreationHandler:
# Validate a newly added alias or newly added alt_aliases.
original_alias = None
original_alt_aliases = [] # type: List[str]
original_alt_aliases: List[str] = []
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:


@ -72,26 +72,26 @@ _SESSION_COOKIES = [
(b"oidc_session_no_samesite", b"HttpOnly"),
]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
Token = TypedDict(
"Token",
{
"access_token": str,
"token_type": str,
"id_token": Optional[str],
"refresh_token": Optional[str],
"expires_in": int,
"scope": Optional[str],
},
)
class Token(TypedDict):
access_token: str
token_type: str
id_token: Optional[str]
refresh_token: Optional[str]
expires_in: int
scope: Optional[str]
#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point of doing this in our case.
JWK = Dict[str, str]
#: A JWK Set, as per RFC7517 sec 5.
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class JWKS(TypedDict):
keys: List[JWK]
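The `Token` and `JWKS` declarations above swap the functional `TypedDict` call for the class-based syntax, which reads like an ordinary class and declares `Optional` fields without a literal dict. A self-contained sketch of the equivalence:

```python
from typing import Optional

from typing_extensions import TypedDict

# Functional form (the style this commit removes).
TokenFunctional = TypedDict(
    "TokenFunctional", {"access_token": str, "id_token": Optional[str]}
)


# Class-based form (the style this commit adopts); the two are equivalent
# to a type checker.
class TokenClass(TypedDict):
    access_token: str
    id_token: Optional[str]
```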
class OidcHandler:
@ -105,9 +105,9 @@ class OidcHandler:
assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
self._providers: Dict[str, "OidcProvider"] = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
} # type: Dict[str, OidcProvider]
}
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
@ -178,7 +178,7 @@ class OidcHandler:
# are two.
for cookie_name, _ in _SESSION_COOKIES:
session = request.getCookie(cookie_name) # type: Optional[bytes]
session: Optional[bytes] = request.getCookie(cookie_name)
if session is not None:
break
else:
@ -255,7 +255,7 @@ class OidcError(Exception):
def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return f"{self.error}: {self.error_description}"
return self.error
@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
self._callback_url: str = hs.config.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
@ -290,7 +290,7 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
client_secret = None # type: Union[None, str, JwtClientSecret]
client_secret: Optional[Union[str, JwtClientSecret]] = None
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
@ -305,7 +305,7 @@ class OidcProvider:
provider.client_id,
client_secret,
provider.client_auth_method,
) # type: ClientAuth
)
self._client_auth_method = provider.client_auth_method
# cache of metadata for the identity provider (endpoint uris, mostly). This is
@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
self._server_name: str = hs.config.server_name
# identifier for the external_ids table
self.idp_id = provider.idp_id
@ -639,7 +639,7 @@ class OidcProvider:
)
logger.warning(description)
# Body was still valid JSON. Might be useful to log it for debugging.
logger.warning("Code exchange response: {resp!r}".format(resp=resp))
logger.warning("Code exchange response: %r", resp)
raise OidcError("server_error", description)
return resp
@ -1217,10 +1217,12 @@ class OidcSessionData:
ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(
"UserAttributeDict",
{"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
)
class UserAttributeDict(TypedDict):
localpart: Optional[str]
display_name: Optional[str]
emails: List[str]
C = TypeVar("C")
@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
emails = [] # type: List[str]
emails: List[str] = []
email = render_template_field(self._config.email_template)
if email:
emails.append(email)
@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
extras: Dict[str, str] = {}
for key, template in self._config.extra_attributes.items():
try:
extras[key] = template.render(user=userinfo).strip()


@ -81,9 +81,9 @@ class PaginationHandler:
self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set() # type: Set[str]
self._purges_in_progress_by_room: Set[str] = set()
# map from purge id to PurgeStatus
self._purges_by_id = {} # type: Dict[str, PurgeStatus]
self._purges_by_id: Dict[str, PurgeStatus] = {}
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime


@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
self._user_to_num_current_syncs = {} # type: Dict[str, int]
self._user_to_num_current_syncs: Dict[str, int] = {}
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
# user_id -> last_sync_ms. Lists the users that have stopped syncing but
# we haven't notified the presence writer of that yet
self.users_going_offline = {} # type: Dict[str, int]
self.users_going_offline: Dict[str, int] = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
# Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted
self.unpersisted_users_changes = set() # type: Set[str]
self.unpersisted_users_changes: Set[str] = set()
hs.get_reactor().addSystemEventTrigger(
"before",
@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self.user_to_num_current_syncs = {} # type: Dict[str, int]
self.user_to_num_current_syncs: Dict[str, int] = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]]
self.external_process_last_updated_ms = {} # type: Dict[str, int]
self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
self.external_process_last_updated_ms: Dict[str, int] = {}
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
@ -1581,9 +1581,7 @@ class PresenceEventSource:
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users = (
set()
) # type: Union[Set[str], FrozenSet[str]]
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
if from_key:
# First get all users that have had a presence update
@ -1950,8 +1948,8 @@ async def get_interested_parties(
A 2-tuple of `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
users_to_states = {} # type: Dict[str, List[UserPresenceState]]
room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
users_to_states: Dict[str, List[UserPresenceState]] = {}
for state in states:
room_ids = await store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
# stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states
# cached.
self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
self._next_id = 1
# Map from instance name to current token
self._current_tokens = {} # type: Dict[str, int]
self._current_tokens: Dict[str, int] = {}
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
# handle the case where `from_token` stream ID has already been dropped.
start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
to_send: List[Tuple[int, Tuple[str, str]]] = []
limited = False
new_id = upto_token
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
if not self._federation:
return
hosts_to_users = {} # type: Dict[str, Set[str]]
hosts_to_users: Dict[str, Set[str]] = {}
for row in rows:
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
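The `start_idx` computation a few hunks up maps the caller's stream token onto the in-memory queue via a negative index, clamped so that tokens older than anything retained simply return the whole queue. A worked example with assumed values:

```python
queue_ids = [4, 5, 6, 7, 8, 9]  # stream ids currently held in the queue
next_id = 10                    # the next stream id to be allocated
from_token = 6                  # the caller already has ids <= 6

# Negative index of the first entry newer than from_token, clamped to the
# start of the queue in case from_token predates everything we kept.
start_idx = max(from_token + 1 - next_id, -len(queue_ids))
assert start_idx == -3
assert queue_ids[start_idx:] == [7, 8, 9]
```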


@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
displayname_to_set = new_displayname # type: Optional[str]
displayname_to_set: Optional[str] = new_displayname
if new_displayname == "":
displayname_to_set = None
@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
avatar_url_to_set = new_avatar_url # type: Optional[str]
avatar_url_to_set: Optional[str] = new_avatar_url
if new_avatar_url == "":
avatar_url_to_set = None


@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
min_batch_id = None # type: Optional[int]
max_batch_id = None # type: Optional[int]
min_batch_id: Optional[int] = None
max_batch_id: Optional[int] = None
for receipt in receipts:
res = await self.store.insert_receipt(


@ -55,15 +55,12 @@ login_counter = Counter(
["guest", "auth_provider"],
)
LoginDict = TypedDict(
"LoginDict",
{
"device_id": str,
"access_token": str,
"valid_until_ms": Optional[int],
"refresh_token": Optional[str],
},
)
class LoginDict(TypedDict):
device_id: str
access_token: str
valid_until_ms: Optional[int]
refresh_token: Optional[str]
class RegistrationHandler(BaseHandler):
@ -77,6 +74,7 @@ class RegistrationHandler(BaseHandler):
self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._account_validity_handler = hs.get_account_validity_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self._server_name = hs.hostname
@ -700,6 +698,10 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
# Only call the account validity module(s) on the main process, to avoid
# repeating e.g. database writes on all of the workers.
await self._account_validity_handler.on_user_registration(user_id)
async def register_device(
self,
user_id: str,


@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
self.config = hs.config
# Room state based off defined presets
self._presets_dict = {
self._presets_dict: Dict[str, Dict[str, Any]] = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": HistoryVisibility.SHARED,
@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
} # type: Dict[str, Dict[str, Any]]
}
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
# If a user tries to update the same room multiple times in quick
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
)
self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
} # type: JsonDict
}
# Check if old room was non-federatable
@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
else:
is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
event_allowed = await self.third_party_event_rules.on_create_room(
# Let the third party rules modify the room creation config if needed, or abort
# the room creation entirely with an exception.
await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
if not event_allowed:
raise SynapseError(
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
@ -936,7 +932,7 @@ class RoomCreationHandler(BaseHandler):
etype=EventTypes.PowerLevels, content=pl_content
)
else:
power_level_content = {
power_level_content: JsonDict = {
"users": {creator_id: 100},
"users_default": 0,
"events": {
@ -955,7 +951,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
} # type: JsonDict
}
if config["original_invitees_have_ops"]:
for invitee in invite_list:


@ -48,12 +48,12 @@ class RoomListHandler(BaseHandler):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
self.remote_response_cache = ResponseCache(
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
self.response_cache: ResponseCache[
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
] = ResponseCache(hs.get_clock(), "room_list")
self.remote_response_cache: ResponseCache[
Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
async def get_local_public_room_list(
self,
@ -140,10 +140,10 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
bounds = (
bounds: Optional[Tuple[int, str]] = (
batch_token.last_joined_members,
batch_token.last_room_id,
) # type: Optional[Tuple[int, str]]
)
forwards = batch_token.direction_is_forward
has_batch_token = True
else:
@ -183,7 +183,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
response = {} # type: JsonDict
response: JsonDict = {}
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@ -384,7 +384,11 @@ class RoomListHandler(BaseHandler):
):
logger.debug("Falling back to locally-filtered /publicRooms")
else:
raise # Not an error that should trigger a fallback.
# Not an error that should trigger a fallback.
raise SynapseError(502, "Failed to fetch room list")
except RequestSendFailed:
# Not an error that should trigger a fallback.
raise SynapseError(502, "Failed to fetch room list")
# if we reach this point, then we fall back to the situation where
# we currently don't support searching across federation, so we have


@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@ -372,7 +372,7 @@ class SamlHandler(BaseHandler):
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
"[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
)
@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
return username
MXID_MAPPER_MAP = {
MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
} # type: Dict[str, Callable[[str], str]]
}
@attr.s


@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms search, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
historical_room_ids = [] # type: List[str]
historical_room_ids: List[str] = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
# Holds result of grouping by room, if applicable
room_groups = {} # type: Dict[str, JsonDict]
room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
sender_group = {} # type: Dict[str, JsonDict]
sender_group: Dict[str, JsonDict] = {}
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
room_events = [] # type: List[EventBase]
room_events: List[EventBase] = []
i = 0
pagination_token = batch_token


@ -90,14 +90,14 @@ class SpaceSummaryHandler:
room_queue = deque((_RoomQueueEntry(room_id, ()),))
# rooms we have already processed
processed_rooms = set() # type: Set[str]
processed_rooms: Set[str] = set()
# events we have already processed. We don't necessarily have their event ids,
# so instead we key on (room id, state key)
processed_events = set() # type: Set[Tuple[str, str]]
processed_events: Set[Tuple[str, str]] = set()
rooms_result = [] # type: List[JsonDict]
events_result = [] # type: List[JsonDict]
rooms_result: List[JsonDict] = []
events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS:
queue_entry = room_queue.popleft()
@ -272,10 +272,10 @@ class SpaceSummaryHandler:
# the set of rooms that we should not walk further. Initialise it with the
# excluded-rooms list; we will add other rooms as we process them so that
# we do not loop.
processed_rooms = set(exclude_rooms) # type: Set[str]
processed_rooms: Set[str] = set(exclude_rooms)
rooms_result = [] # type: List[JsonDict]
events_result = [] # type: List[JsonDict]
rooms_result: List[JsonDict] = []
events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft()
@ -353,7 +353,7 @@ class SpaceSummaryHandler:
max_children = MAX_ROOMS_PER_SPACE
now = self._clock.time_msec()
events_result = [] # type: List[JsonDict]
events_result: List[JsonDict] = []
for edge_event in itertools.islice(child_events, max_children):
events_result.append(
await self._event_serializer.serialize_event(


@ -202,10 +202,10 @@ class SsoHandler:
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
# a map from session id to session data
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
# map from idp_id to SsoIdentityProvider
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
self._identity_providers: Dict[str, SsoIdentityProvider] = {}
self._consent_at_registration = hs.config.consent.user_consent_at_registration
@ -296,7 +296,7 @@ class SsoHandler:
)
# if the client chose an IdP, use that
idp = None # type: Optional[SsoIdentityProvider]
idp: Optional[SsoIdentityProvider] = None
if idp_id:
idp = self._identity_providers.get(idp_id)
if not idp:
@ -669,9 +669,9 @@ class SsoHandler:
remote_user_id,
)
user_id_to_verify = await self._auth_handler.get_session_data(
user_id_to_verify: str = await self._auth_handler.get_session_data(
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
)
if not user_id:
logger.warning(
@ -793,7 +793,7 @@ class SsoHandler:
session.use_display_name = use_display_name
emails_from_idp = set(session.emails)
filtered_emails = set() # type: Set[str]
filtered_emails: Set[str] = set()
# we iterate through the list rather than just taking a set intersection, so
# that we can log attempts to use unknown addresses


@ -49,7 +49,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
self.pos = None # type: Optional[int]
self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@ -131,10 +131,10 @@ class StatsHandler:
mapping from room/user ID to changes in the various fields.
"""
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
room_to_stats_deltas: Dict[str, CounterType[str]] = {}
user_to_stats_deltas: Dict[str, CounterType[str]] = {}
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
room_to_state_updates: Dict[str, Dict[str, Any]] = {}
for delta in deltas:
typ = delta["type"]
@ -164,7 +164,7 @@ class StatsHandler:
)
continue
event_content = {} # type: JsonDict
event_content: JsonDict = {}
if event_id is not None:
event = await self.store.get_event(event_id, allow_none=True)


@ -279,12 +279,14 @@ class SyncHandler:
self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
self.lazy_loaded_members_cache: ExpiringCache[
Tuple[str, Optional[str]], LruCache[str, str]
] = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
)
async def wait_for_sync_for_user(
self,
@ -441,7 +443,7 @@ class SyncHandler:
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
ephemeral_by_room = {} # type: JsonDict
ephemeral_by_room: JsonDict = {}
for event in typing:
# we want to exclude the room_id from the event, but modifying the
@ -503,7 +505,7 @@ class SyncHandler:
# We check if there are any state events; if there are, we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline.
current_state_ids = frozenset() # type: FrozenSet[str]
current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
current_state_ids_map = await self.store.get_current_state_ids(
room_id
@ -784,9 +786,9 @@ class SyncHandler:
def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
) -> LruCache[str, str]:
cache = self.lazy_loaded_members_cache.get(
cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
cache_key
) # type: Optional[LruCache[str, str]]
)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@ -985,7 +987,7 @@ class SyncHandler:
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
state = {} # type: Dict[str, EventBase]
state: Dict[str, EventBase] = {}
if state_ids:
state = await self.store.get_events(list(state_ids.values()))
@ -1089,8 +1091,8 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict
unused_fallback_key_types = [] # type: List[str]
one_time_key_counts: JsonDict = {}
unused_fallback_key_types: List[str] = []
if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
@ -1438,7 +1440,7 @@ class SyncHandler:
)
if block_all_room_ephemeral:
ephemeral_by_room = {} # type: Dict[str, List[JsonDict]]
ephemeral_by_room: Dict[str, List[JsonDict]] = {}
else:
now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder,
@ -1469,7 +1471,7 @@ class SyncHandler:
# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users = frozenset() # type: FrozenSet[str]
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
@ -1587,7 +1589,7 @@ class SyncHandler:
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]]
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@ -1600,7 +1602,7 @@ class SyncHandler:
logger.debug(
"Membership changes in %s: [%s]",
room_id,
", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
", ".join("%s (%s)" % (e.event_id, e.membership) for e in events),
)
non_joins = [e for e in events if e.membership != Membership.JOIN]
@ -1723,7 +1725,7 @@ class SyncHandler:
# This is all screaming out for a refactor, as the logic here is
# subtle and the moving parts numerous.
if leave_event.internal_metadata.is_out_of_band_membership():
batch_events = [leave_event] # type: Optional[List[EventBase]]
batch_events: Optional[List[EventBase]] = [leave_event]
else:
batch_events = None
@ -1972,7 +1974,7 @@ class SyncHandler:
room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
summary = {} # type: Optional[JsonDict]
summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
@ -1996,7 +1998,7 @@ class SyncHandler:
)
if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, int]
unread_notifications: Dict[str, int] = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,


@ -68,11 +68,11 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
self._room_serials = {} # type: Dict[str, int]
self._room_serials: Dict[str, int] = {}
# map room IDs to sets of users currently typing
self._room_typing = {} # type: Dict[str, Set[str]]
self._room_typing: Dict[str, Set[str]] = {}
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
# clock time we expect to stop
self._member_typing_until = {} # type: Dict[RoomMember, int]
self._member_typing_until: Dict[RoomMember, int] = {}
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
) # type: Optional[Iterable[str]]
changed_rooms: Optional[
Iterable[str]
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
if changed_rooms is None:
changed_rooms = self._room_serials


@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
self.pos = None # type: Optional[int]
self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False


@ -172,7 +172,7 @@ class ProxyAgent(_AgentBase):
"""
uri = uri.strip()
if not _VALID_URI.match(uri):
raise ValueError("Invalid URI {!r}".format(uri))
raise ValueError(f"Invalid URI {uri!r}")
parsed_uri = URI.fromBytes(uri)
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)


@ -384,7 +384,7 @@ class SynapseRequest(Request):
# authenticated (e.g. an admin is puppetting a user) then we log both.
requester, authenticated_entity = self.get_authenticated_entity()
if authenticated_entity:
requester = "{}.{}".format(authenticated_entity, requester)
requester = f"{authenticated_entity}.{requester}"
self.site.access_logger.log(
log_level,


@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"):
config = JaegerConfig(
config=hs.config.jaeger_config,
service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
service_name=f"{hs.config.server_name} {hs.get_instance_name()}",
scope_manager=LogContextScopeManager(hs.config),
metrics_factory=PrometheusMetricsFactory(),
)


@ -34,7 +34,7 @@ from twisted.web.resource import Resource
from synapse.util import caches
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
INF = float("inf")
@ -55,8 +55,8 @@ def floatToGoString(d):
# Go switches to exponents sooner than Python.
# We only need to care about positive values for le/quantile.
if d > 0 and dot > 6:
mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.")
return "{0}e+0{1}".format(mantissa, dot - 1)
mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.")
return f"{mantissa}e+0{dot - 1}"
return s
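The mantissa juggling above exists because Go switches to exponent notation sooner than Python does, and Prometheus expects the Go rendering for `le`/`quantile` label values. A worked example of the transformation (values chosen for illustration):

```python
d = 12345678.0
s = str(d)         # "12345678.0"
dot = s.find(".")  # 8, i.e. more than 6 integer digits

mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1:]}".rstrip("0.")
assert mantissa == "1.2345678"
assert f"{mantissa}e+0{dot - 1}" == "1.2345678e+07"  # matches Go's rendering
```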
@ -65,7 +65,7 @@ def sample_line(line, name):
labelstr = "{{{0}}}".format(
",".join(
[
'{0}="{1}"'.format(
'{}="{}"'.format(
k,
v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
)
@ -78,10 +78,8 @@ def sample_line(line, name):
timestamp = ""
if line.timestamp is not None:
# Convert to milliseconds.
timestamp = " {0:d}".format(int(float(line.timestamp) * 1000))
return "{0}{1} {2}{3}\n".format(
name, labelstr, floatToGoString(line.value), timestamp
)
timestamp = f" {int(float(line.timestamp) * 1000):d}"
return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
def generate_latest(registry, emit_help=False):
@ -118,12 +116,12 @@ def generate_latest(registry, emit_help=False):
# Output in the old format for compatibility.
if emit_help:
output.append(
"# HELP {0} {1}\n".format(
"# HELP {} {}\n".format(
mname,
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
)
)
output.append("# TYPE {0} {1}\n".format(mname, mtype))
output.append(f"# TYPE {mname} {mtype}\n")
om_samples: Dict[str, List[str]] = {}
for s in metric.samples:
@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False):
for suffix, lines in sorted(om_samples.items()):
if emit_help:
output.append(
"# HELP {0}{1} {2}\n".format(
"# HELP {}{} {}\n".format(
metric.name,
suffix,
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
)
)
output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix))
output.append(f"# TYPE {metric.name}{suffix} gauge\n")
output.extend(lines)
# Get rid of the weird colon things while we're at it
@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False):
# Also output in the new format, if it's different.
if emit_help:
output.append(
"# HELP {0} {1}\n".format(
"# HELP {} {}\n".format(
mnewname,
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
)
)
output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
output.append(f"# TYPE {mnewname} {mtype}\n")
for s in metric.samples:
# Get rid of the OpenMetrics specific samples (we should already have


@ -137,8 +137,7 @@ class _Collector:
_background_process_db_txn_duration,
_background_process_db_sched_duration,
):
for r in m.collect():
yield r
yield from m.collect()
REGISTRY.register(_Collector())


@ -12,18 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import email.utils
import logging
from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
)
import jinja2
from twisted.internet import defer
from twisted.web.resource import IResource
from synapse.events import EventBase
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
DirectServeJsonResource,
respond_with_html,
)
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, UserID, create_requester
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi
are loaded into Synapse.
"""
__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
__all__ = [
"errors",
"make_deferred_yieldable",
"parse_json_object_from_request",
"respond_with_html",
"run_in_background",
"cached",
"UserID",
"DatabasePool",
"LoggingTransaction",
"DirectServeHtmlResource",
"DirectServeJsonResource",
"ModuleApi",
]
logger = logging.getLogger(__name__)
@ -52,12 +89,28 @@ class ModuleApi:
self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"]
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler()
try:
app_name = self._hs.config.email_app_name
self._from_string = self._hs.config.email_notif_from % {"app": app_name}
except (KeyError, TypeError):
# If substitution failed (which can happen if the string contains
# placeholders other than just "app", or if the type of the placeholder is
# not a string), fall back to the bare strings.
self._from_string = self._hs.config.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
# We expose these as properties below in order to attach a helpful docstring.
self._http_client: SimpleHttpClient = hs.get_simple_http_client()
self._public_room_list_manager = PublicRoomListManager(hs)
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
#################################################################################
# The following methods should only be called during the module's initialisation.
@ -67,6 +120,16 @@ class ModuleApi:
"""Registers callbacks for spam checking capabilities."""
return self._spam_checker.register_callbacks
@property
def register_account_validity_callbacks(self):
"""Registers callbacks for account validity capabilities."""
return self._account_validity_handler.register_account_validity_callbacks
@property
def register_third_party_rules_callbacks(self):
"""Registers callbacks for third party event rules capabilities."""
return self._third_party_event_rules.register_third_party_rules_callbacks
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.
@ -101,22 +164,56 @@ class ModuleApi:
"""
return self._public_room_list_manager
def get_user_by_req(self, req, allow_guest=False):
@property
def public_baseurl(self) -> str:
"""The configured public base URL for this homeserver."""
return self._hs.config.public_baseurl
@property
def email_app_name(self) -> str:
"""The application name configured in the homeserver's configuration."""
return self._hs.config.email.email_app_name
async def get_user_by_req(
self,
req: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
) -> Requester:
"""Check the access_token provided for a request
Args:
req (twisted.web.server.Request): Incoming HTTP request
allow_guest (bool): True if guest users should be allowed. If this
req: Incoming HTTP request
allow_guest: True if guest users should be allowed. If this
is False, and the access token is for a guest user, an
AuthError will be thrown
allow_expired: True if expired users should be allowed. If this
is False, and the access token is for an expired user, an
AuthError will be thrown
Returns:
twisted.internet.defer.Deferred[synapse.types.Requester]:
the requester for this request
The requester for this request
Raises:
synapse.api.errors.AuthError: if no user by that token exists,
InvalidClientCredentialsError: if no user by that token exists,
or the token is invalid.
"""
return self._auth.get_user_by_req(req, allow_guest)
return await self._auth.get_user_by_req(
req,
allow_guest,
allow_expired=allow_expired,
)
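Since `get_user_by_req` is now a coroutine, module code must `await` it. A toy sketch of a module-served resource using it (the resource and its behaviour are invented for illustration; `DirectServeJsonResource` is re-exported in `__all__` above):

```python
from synapse.module_api import DirectServeJsonResource, ModuleApi


class WhoAmIResource(DirectServeJsonResource):
    """An illustrative resource a module might serve."""

    def __init__(self, api: ModuleApi):
        super().__init__()
        self._api = api

    async def _async_render_GET(self, request):
        # Let expired users through so that e.g. a renewal page can still
        # identify them.
        requester = await self._api.get_user_by_req(request, allow_expired=True)
        return 200, {"user_id": requester.user.to_string()}
```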
async def is_user_admin(self, user_id: str) -> bool:
"""Checks if a user is a server admin.
Args:
user_id: The Matrix ID of the user to check.
Returns:
True if the user is a server admin, False otherwise.
"""
return await self._store.is_server_admin(UserID.from_string(user_id))
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
@ -134,6 +231,32 @@ class ModuleApi:
return username
return UserID(username, self._hs.hostname).to_string()
async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
"""Look up the profile info for the user with the given localpart.
Args:
localpart: The localpart to look up profile information for.
Returns:
The profile information (i.e. display name and avatar URL).
"""
return await self._store.get_profileinfo(localpart)
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
"""Look up the threepids (email addresses and phone numbers) associated with the
given Matrix user ID.
Args:
user_id: The Matrix user ID to look up threepids for.
Returns:
A list of threepids, each represented by a dictionary with a "medium"
key whose value is "email" for email addresses or "msisdn" for phone
numbers, and an "address" key whose value is the threepid's address.
"""
return await self._store.user_get_threepids(user_id)
def check_user_exists(self, user_id):
"""Check if user exists.
@ -464,6 +587,88 @@ class ModuleApi:
presence_events, destination
)
def looping_background_call(
self,
f: Callable,
msec: float,
*args,
desc: Optional[str] = None,
**kwargs,
):
"""Wraps a function as a background process and calls it repeatedly.
Waits `msec` initially before calling `f` for the first time.
Args:
f: The function to call repeatedly. f can be either synchronous or
asynchronous, and must follow Synapse's logcontext rules.
More info about logcontexts is available at
https://matrix-org.github.io/synapse/latest/log_contexts.html
msec: How long to wait between calls in milliseconds.
*args: Positional arguments to pass to the function.
desc: The background task's description. Defaults to the function's name.
**kwargs: Keyword arguments to pass to the function.
"""
if desc is None:
desc = f.__name__
if self._hs.config.run_background_tasks:
self._clock.looping_call(
run_as_background_process,
msec,
desc,
f,
*args,
**kwargs,
)
else:
logger.warning(
"Not running looping call %s as the configuration forbids it",
f,
)
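A usage sketch for the helper above, scheduling a periodic job from a module's constructor (the module, the job, and the interval are all illustrative):

```python
class ExampleModule:
    def __init__(self, config: dict, api):
        self._api = api
        # Prune module-private state once an hour; this is a no-op on
        # processes not configured to run background tasks.
        api.looping_background_call(self._prune, 60 * 60 * 1000, desc="example_prune")

    async def _prune(self) -> None:
        ...  # clean-up work goes here
```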
async def send_mail(
self,
recipient: str,
subject: str,
html: str,
text: str,
):
"""Send an email on behalf of the homeserver.
Args:
recipient: The email address for the recipient.
subject: The email's subject.
html: The email's HTML content.
text: The email's text content.
"""
await self._send_email_handler.send_email(
email_address=recipient,
subject=subject,
app_name=self.email_app_name,
html=html,
text=text,
)
def read_templates(
self,
filenames: List[str],
custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]:
"""Read and load the content of the template files at the given location.
By default, Synapse will look for these templates in its configured template
directory, but another directory to search in can be provided.
Args:
filenames: The name of the template files to look for.
custom_template_directory: An additional directory to look for the files in.
Returns:
A list containing the loaded templates, in the same order as the
filenames parameter.
"""
return self._hs.config.read_templates(filenames, custom_template_directory)
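A usage sketch (the file name and directory are made up); the return value is a list of ordinary Jinja2 templates:

```python
# Inside a module that was handed a ModuleApi instance as `api`:
(notice_template,) = api.read_templates(
    ["my_module_notice.html"],
    custom_template_directory="/etc/synapse/my-module-templates",
)
html = notice_template.render(user="@alice:example.org")
```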
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room


@ -14,5 +14,9 @@
"""Exception types which are exposed as part of the stable module API"""
from synapse.api.errors import RedirectException, SynapseError # noqa: F401
from synapse.api.errors import ( # noqa: F401
InvalidClientCredentialsError,
RedirectException,
SynapseError,
)
from synapse.config._base import ConfigError # noqa: F401


@ -62,10 +62,6 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
)
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
@ -89,6 +85,8 @@ class PusherPool:
# map from user id to app_id:pushkey to pusher
self.pushers: Dict[str, Dict[str, Pusher]] = {}
self._account_validity_handler = hs.get_account_validity_handler()
def start(self) -> None:
"""Starts the pushers off in a background process."""
if not self._should_start_pushers:
@ -238,12 +236,9 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
if self._account_validity_enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
if expired:
continue
expired = await self._account_validity_handler.is_user_expired(u)
if expired:
continue
if u in self.pushers:
for p in self.pushers[u].values():
@ -268,12 +263,9 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
if self._account_validity_enabled:
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
if expired:
continue
expired = await self._account_validity_handler.is_user_expired(u)
if expired:
continue
if u in self.pushers:
for p in self.pushers[u].values():
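Both hunks above route the expiry check through the account validity handler, which consults any registered account validity modules. A sketch of what such a module might look like (the callback names follow the new account validity module interface; the storage helpers are module-private placeholders):

```python
import time
from typing import Optional


class ExampleAccountValidityModule:
    def __init__(self, config: dict, api):
        self._api = api
        api.register_account_validity_callbacks(
            is_user_expired=self.is_user_expired,
            on_user_registration=self.on_user_registration,
        )

    async def is_user_expired(self, user_id: str) -> Optional[bool]:
        # None means "no opinion": other modules, or the default of
        # "not expired", get to decide.
        expiry_ts = await self._lookup_expiry(user_id)  # placeholder helper
        if expiry_ts is None:
            return None
        return expiry_ts <= int(time.time() * 1000)

    async def on_user_registration(self, user_id: str) -> None:
        # Grant new accounts a 30-day validity window (policy made up).
        thirty_days_ms = 30 * 24 * 60 * 60 * 1000
        expiry_ts = int(time.time() * 1000) + thirty_days_ms
        await self._store_expiry(user_id, expiry_ts)  # placeholder helper
```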


@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
# Get the room ID from the identifier.
try:
remote_room_hosts = [
remote_room_hosts: Optional[List[str]] = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
]
except Exception:
remote_room_hosts = None
room_id, remote_room_hosts = await self.resolve_room_id(
@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
else:
event_filter = None


@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
self.nonces = {} # type: Dict[str, int]
self.nonces: Dict[str, int] = {}
self.hs = hs
def _clear_old_nonces(self):
@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
if self.account_activity_handler.on_legacy_admin_request_callback:
expiration_ts = await (
self.account_activity_handler.on_legacy_admin_request_callback(request)
)
else:
body = parse_json_object_from_request(request)
if "user_id" not in body:
raise SynapseError(400, "Missing property 'user_id' in the request body")
if "user_id" not in body:
raise SynapseError(
400,
"Missing property 'user_id' in the request body",
)
expiration_ts = await self.account_activity_handler.renew_account_for_user(
body["user_id"],
body.get("expiration_ts"),
not body.get("enable_renewal_emails", True),
)
res = {"expiration_ts": expiration_ts}
return 200, res
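For reference, the servlet above backs the admin account-validity renewal API; the request and response bodies look roughly like this (the endpoint path, `/_synapse/admin/v1/account_validity/validity`, and all values are shown as assumptions):

```python
# Illustrative request/response bodies for the admin renewal endpoint.
request_body = {
    "user_id": "@alice:example.org",
    "expiration_ts": 1626825600000,  # optional; omitted => server-chosen
    "enable_renewal_emails": True,   # optional; defaults to True
}
response_body = {"expiration_ts": 1626825600000}
```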


@ -44,19 +44,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
LoginResponse = TypedDict(
"LoginResponse",
{
"user_id": str,
"access_token": str,
"home_server": str,
"expires_in_ms": Optional[int],
"refresh_token": Optional[str],
"device_id": str,
"well_known": Optional[Dict[str, Any]],
},
total=False,
)
class LoginResponse(TypedDict, total=False):
user_id: str
access_token: str
home_server: str
expires_in_ms: Optional[int]
refresh_token: Optional[str]
device_id: str
well_known: Optional[Dict[str, Any]]
class LoginRestServlet(RestServlet):
@ -121,7 +116,7 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
sso_flow = {
sso_flow: JsonDict = {
"type": LoginRestServlet.SSO_TYPE,
"identity_providers": [
_get_auth_flow_dict_for_idp(
@ -129,7 +124,7 @@ class LoginRestServlet(RestServlet):
)
for idp in self._sso_handler.get_identity_providers().values()
],
} # type: JsonDict
}
if self._msc2858_enabled:
# backwards-compatibility support for clients which don't
@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
)
flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp(
use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
@ -561,7 +554,7 @@ class SsoRedirectServlet(RestServlet):
finish_request(request)
return
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
args: Dict[bytes, List[bytes]] = request.args # type: ignore
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
sso_url = await self._sso_handler.handle_redirect_request(
request,


@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100)) # type: Optional[int]
limit: Optional[int] = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
else:
event_filter = None


@ -14,7 +14,7 @@
import logging
from synapse.api.errors import AuthError, SynapseError
from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet
@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet):
)
async def on_POST(self, request):
if not self.account_validity_renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)

View file

@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"]
)
response = (200, {}) # type: Tuple[int, dict]
response: Tuple[int, dict] = (200, {})
return response

View file

@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
has_consented = False
public_version = username == ""
if not public_version:
args = request.args # type: Dict[bytes, List[bytes]]
args: Dict[bytes, List[bytes]] = request.args
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
"""
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
args = request.args # type: Dict[bytes, List[bytes]]
args: Dict[bytes, List[bytes]] = request.args
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)

View file

@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
query = {server.decode("ascii"): {}} # type: dict
query: dict = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
# Note that the value is unused.
cache_misses = {} # type: Dict[str, Dict[str, int]]
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]

View file

@ -17,7 +17,7 @@ import PIL.Image
# check for JPEG support.
try:
PIL.Image._getdecoder("rgb", "jpeg", None)
except IOError as e:
except OSError as e:
if str(e).startswith("decoder jpeg not available"):
raise Exception(
"FATAL: jpeg codec not supported. Install pillow correctly! "
@ -32,7 +32,7 @@ except Exception:
# check for PNG support.
try:
PIL.Image._getdecoder("rgb", "zip", None)
except IOError as e:
except OSError as e:
if str(e).startswith("decoder zip not available"):
raise Exception(
"FATAL: zip codec not supported. Install pillow correctly! "

View file
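The `IOError` → `OSError` rewrites just above (repeated later in `daemonize.py`) are behaviour-preserving: since Python 3.3, `IOError` is a plain alias of `OSError`, and `pyupgrade` normalises to the canonical name. A quick demonstration:

```python
# IOError has been an alias of OSError since Python 3.3.
assert IOError is OSError

try:
    open("/nonexistent/path")
except OSError as e:  # catches exactly what `except IOError` used to catch
    print(type(e).__name__, e)
```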

@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath = request.postpath # type: List[bytes] # type: ignore
postpath: List[bytes] = request.postpath # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for

View file

@ -78,16 +78,16 @@ class MediaRepository:
Thumbnailer.set_limits(self.max_image_pixels)
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.primary_base_path: str = hs.config.media_store_path
self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() # type: Set[str]
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
self.recently_accessed_locals: Set[str] = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -711,7 +711,7 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions as a scaled one.
thumbnails = {} # type: Dict[Tuple[int, int, str], str]
thumbnails: Dict[Tuple[int, int, str], str] = {}
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)

View file

@ -191,7 +191,7 @@ class MediaStorage:
for provider in self.storage_providers:
for path in paths:
res = await provider.fetch(path, file_info) # type: Any
res: Any = await provider.fetch(path, file_info)
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
@ -233,7 +233,7 @@ class MediaStorage:
os.makedirs(dirname)
for provider in self.storage_providers:
res = await provider.fetch(path, file_info) # type: Any
res: Any = await provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(

View file

@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
self._cache = ExpiringCache(
self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
) # type: ExpiringCache[str, ObservableDeferred]
)
if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
url_to_download = url # type: Optional[str]
url_to_download: Optional[str] = url
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {} # type: Dict[str, Optional[str]]
og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss

View file

@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
errcode=Codes.TOO_LARGE,
)
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
args: Dict[bytes, List[bytes]] = request.args # type: ignore
upload_name_bytes = parse_bytes_from_args(args, "filename")
if upload_name_bytes:
try:
upload_name = upload_name_bytes.decode("utf8") # type: Optional[str]
upload_name: Optional[str] = upload_name_bytes.decode("utf8")
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-disposition
try:
content = request.content # type: IO # type: ignore
content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content(
media_type, upload_name, content, content_length, requester.user
)

View file

@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
use_display_name = parse_boolean(request, "use_display_name", default=False)
try:
emails_to_use = [
emails_to_use: List[str] = [
val.decode("utf-8") for val in request.args.get(b"use_email", [])
] # type: List[str]
]
except ValueError:
raise SynapseError(400, "Query parameter use_email must be utf-8")
except SynapseError as e:

View file

@ -907,7 +907,7 @@ class DatabasePool:
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(
*[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
*(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
)
for k in keys:

View file
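The `DatabasePool` hunk above only swaps a list comprehension for a generator, but the double-`zip` it feeds is easy to misread. A self-contained sketch of what it computes, using a hypothetical `values` list:

```python
values = [{"b": 2, "a": 1}, {"a": 3, "b": 4}]

# Each inner zip(*sorted(...)) splits one dict into a (keys, values) pair in
# a deterministic key order; the outer zip then gathers all key tuples into
# `keys` and all value tuples into `vals`.
keys, vals = zip(
    *(zip(*sorted(d.items(), key=lambda kv: kv[0])) for d in values if d)
)

print(keys)  # (('a', 'b'), ('a', 'b'))
print(vals)  # ((1, 2), (3, 4))
```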

@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_messages_for_device", delete_messages_for_device_txn
)
log_kv(
{"message": "deleted {} messages for device".format(count), "count": count}
)
log_kv({"message": f"deleted {count} messages for device", "count": count})
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(

View file

@ -2010,10 +2010,6 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
events_by_room: Dict[str, List[EventBase]] = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("

View file

@ -27,8 +27,11 @@ from synapse.util import json_encoder
_DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = ""
# A room in a group.
_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
class _RoomInGroup(TypedDict):
room_id: str
is_public: bool
class GroupServerWorkerStore(SQLBaseStore):
@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_public": False # Whether this is a public room or not
}
"""
# TODO: Pagination
def _get_rooms_in_group_txn(txn):

View file

@ -316,6 +316,135 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
async def count_r30v2_users(self) -> Dict[str, int]:
"""
Counts the number of 30-day retained users, defined as users that:
- Appear more than once in the past 60 days
- Have more than 30 days between the most and least recent appearances that
occurred in the past 60 days.
(This is the second version of this metric, hence the "v2" in R30v2.)
Returns:
A mapping from client type to the number of 30-day retained users for that client.
The dict keys are:
- "all" (a combined number of users across any and all clients)
- "android" (Element Android)
- "ios" (Element iOS)
- "electron" (Element Desktop)
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
def _count_r30v2_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
one_day_from_now_in_secs = now + 86400
# This is the 'per-platform' count.
sql = """
SELECT
client_type,
count(client_type)
FROM
(
SELECT
user_id,
CASE
WHEN
LOWER(user_agent) LIKE '%%riot%%' OR
LOWER(user_agent) LIKE '%%element%%'
THEN CASE
WHEN
LOWER(user_agent) LIKE '%%electron%%'
THEN 'electron'
WHEN
LOWER(user_agent) LIKE '%%android%%'
THEN 'android'
WHEN
LOWER(user_agent) LIKE '%%ios%%'
THEN 'ios'
ELSE 'unknown'
END
WHEN
LOWER(user_agent) LIKE '%%mozilla%%' OR
LOWER(user_agent) LIKE '%%gecko%%'
THEN 'web'
ELSE 'unknown'
END as client_type
FROM
user_daily_visits
WHERE
timestamp > ?
AND
timestamp < ?
GROUP BY
user_id,
client_type
HAVING
max(timestamp) - min(timestamp) > ?
) AS temp
GROUP BY
client_type
;
"""
# We initialise all the client types to zero, so we get an explicit
# zero if they don't appear in the query results
results = {"ios": 0, "android": 0, "web": 0, "electron": 0}
txn.execute(
sql,
(
sixty_days_ago_in_secs * 1000,
one_day_from_now_in_secs * 1000,
thirty_days_in_secs * 1000,
),
)
for row in txn:
if row[0] == "unknown":
continue
results[row[0]] = row[1]
# This is the 'all users' count.
sql = """
SELECT COUNT(*) FROM (
SELECT
1
FROM
user_daily_visits
WHERE
timestamp > ?
AND
timestamp < ?
GROUP BY
user_id
HAVING
max(timestamp) - min(timestamp) > ?
) AS r30_users
"""
txn.execute(
sql,
(
sixty_days_ago_in_secs * 1000,
one_day_from_now_in_secs * 1000,
thirty_days_in_secs * 1000,
),
)
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
return results
return await self.db_pool.runInteraction(
"count_r30v2_users", _count_r30v2_users
)
def _get_start_of_day(self):
"""
Returns millisecond unixtime for start of UTC day.

View file
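The SQL above is dense; as a rough restatement, the R30v2 rule counts a user when their earliest and latest visits within the last 60 days are more than 30 days apart. A pure-Python sketch of the "all" count under that reading (not Synapse code; `visits` is a hypothetical `(user_id, timestamp_secs)` list):

```python
from typing import Dict, List, Tuple

ONE_DAY_SECS = 86400


def r30v2_all(visits: List[Tuple[str, int]], now: int) -> int:
    """Count users whose first and last visits in the past 60 days are
    more than 30 days apart (the R30v2 'all' figure)."""
    window_start = now - 60 * ONE_DAY_SECS
    per_user: Dict[str, List[int]] = {}
    for user_id, ts in visits:
        if window_start < ts < now + ONE_DAY_SECS:
            per_user.setdefault(user_id, []).append(ts)
    return sum(
        1
        for stamps in per_user.values()
        if max(stamps) - min(stamps) > 30 * ONE_DAY_SECS
    )


# Visits 35 days apart count; a single burst of activity does not.
now = 100 * ONE_DAY_SECS
print(r30v2_all([("@a:hs", now - 40 * ONE_DAY_SECS),
                 ("@a:hs", now - 5 * ONE_DAY_SECS)], now))  # 1
print(r30v2_all([("@b:hs", now - 5 * ONE_DAY_SECS)], now))  # 0
```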

@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
users_in_room.update(row for row in event_to_memberships.values() if row)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:

View file

@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f:
with open(schema_path) as f:
execute_statements_from_stream(txn, f)

View file

@ -577,10 +577,10 @@ class RoomStreamToken:
entries = []
for name, pos in self.instance_map.items():
instance_id = await store.get_id_for_instance(name)
entries.append("{}.{}".format(instance_id, pos))
entries.append(f"{instance_id}.{pos}")
encoded_map = "~".join(entries)
return "m{}~{}".format(self.stream, encoded_map)
return f"m{self.stream}~{encoded_map}"
else:
return "s%d" % (self.stream,)

View file

@ -90,8 +90,7 @@ def enumerate_leaves(node, depth):
yield node
else:
for n in node.values():
for m in enumerate_leaves(n, depth - 1):
yield m
yield from enumerate_leaves(n, depth - 1)
P = TypeVar("P")

View file
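This and the following hunk are the same `pyupgrade` rewrite: an inner `for ... yield` loop delegates to `yield from`, which is equivalent for plain iteration. A self-contained sketch:

```python
def leaves(node, depth):
    # Before: for m in leaves(n, depth - 1): yield m
    # After:  yield from leaves(n, depth - 1)
    if depth == 0:
        yield node
    else:
        for n in node.values():
            yield from leaves(n, depth - 1)


tree = {"a": {"x": 1}, "b": {"y": 2}}
print(list(leaves(tree, 2)))  # [1, 2]
```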

@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d):
"""
if isinstance(d, TreeCacheNode):
for value_d in d.values():
for value in iterate_tree_cache_entry(value_d):
yield value
yield from iterate_tree_cache_entry(value_d)
else:
yield d

View file

@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# If the pidfile already exists, read the old pid from it now: the locking
# attempt below somehow purges the file contents, so the pid would be lost
# if locking fails.
if os.path.isfile(pid_file):
with open(pid_file, "r") as pid_fh:
with open(pid_file) as pid_fh:
old_pid = pid_fh.read()
# Create a lockfile so that only one instance of this daemon is running at any time.
try:
lock_fh = open(pid_file, "w")
except IOError:
except OSError:
print("Unable to create the pidfile.")
sys.exit(1)
@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# Try to get an exclusive lock on the file. This will fail if another process
# has the file locked.
fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB)
except IOError:
except OSError:
print("Unable to lock on the pidfile.")
# We need to overwrite the pidfile if we got here.
#
@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
try:
lock_fh.write("%s" % (os.getpid()))
lock_fh.flush()
except IOError:
except OSError:
logger.error("Unable to write pid to the pidfile.")
print("Unable to write pid to the pidfile.")
sys.exit(1)

View file

@ -96,7 +96,7 @@ async def filter_events_for_client(
if isinstance(ignored_users_dict, dict):
ignore_list = frozenset(ignored_users_dict.keys())
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
if filter_send_to_client:
room_ids = {e.room_id for e in events}
@ -353,7 +353,7 @@ async def filter_events_for_server(
)
if not check_history_visibility_only:
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.

View file

@ -1,9 +1,11 @@
import synapse
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.rest.client.v1 import login, room
from tests import unittest
from tests.unittest import HomeserverTestCase
FIVE_MINUTES_IN_SECONDS = 300
ONE_DAY_IN_SECONDS = 86400
@ -151,3 +153,243 @@ class PhoneHomeTestCase(HomeserverTestCase):
# *Now* the user appears in R30.
r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
self.assertEqual(r30_results, {"all": 1, "unknown": 1})
class PhoneHomeR30V2TestCase(HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def _advance_to(self, desired_time_secs: float):
now = self.hs.get_clock().time()
assert now < desired_time_secs
self.reactor.advance(desired_time_secs - now)
def make_homeserver(self, reactor, clock):
hs = super().make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check
# that it's not enabled
assert not hs.config.report_stats
# This starts the needed data collection that we rely on to calculate
# R30v2 metrics.
start_phone_stats_home(hs)
return hs
def test_r30v2_minimum_usage(self):
"""
Tests the minimum amount of interaction necessary for the R30v2 metric
to consider a user 'retained'.
"""
# Register a user, log it in, create a room and send a message
user_id = self.register_user("u1", "secret!")
access_token = self.login("u1", "secret!")
room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
self.helper.send(room_id, "message", tok=access_token)
first_post_at = self.hs.get_clock().time()
# Give time for user_daily_visits table to be updated.
# (user_daily_visits is updated every 5 minutes using a looping call.)
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
store = self.hs.get_datastore()
# Check the R30 results do not count that user.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
# Advance 31 days.
# (R30v2 includes users with **more** than 30 days between the two visits,
# and user_daily_visits records the timestamp as the start of the day.)
self.reactor.advance(31 * ONE_DAY_IN_SECONDS)
# Also advance 5 minutes to let another user_daily_visits update occur
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
# (Make sure the user isn't somehow counted by this point.)
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
# Send a message (this counts as activity)
self.helper.send(room_id, "message2", tok=access_token)
# We have to wait a few minutes for the user_daily_visits table to
# be updated by a background process.
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
# *Now* the user is counted.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
# Advance to JUST under 60 days after the user's first post
self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5)
# Check the user is still counted.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
# Advance into the next day. The user's first activity is now more than 60 days old.
self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5)
# Check the user is now no longer counted in R30.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
def test_r30v2_user_must_be_retained_for_at_least_a_month(self):
"""
Tests that a newly-registered user must be retained for a whole month
before appearing in the R30v2 statistic, even if they post every day
during that time!
"""
# set a custom user-agent to impersonate Element/Android.
headers = (
(
"User-Agent",
"Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
),
)
# Register a user and send a message
user_id = self.register_user("u1", "secret!")
access_token = self.login("u1", "secret!", custom_headers=headers)
room_id = self.helper.create_room_as(
room_creator=user_id, tok=access_token, custom_headers=headers
)
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
# Give time for user_daily_visits table to be updated.
# (user_daily_visits is updated every 5 minutes using a looping call.)
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
store = self.hs.get_datastore()
# Check the user does not contribute to R30 yet.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
for _ in range(30):
# This loop posts a message every day for 30 days
self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS)
self.helper.send(
room_id, "I'm still here", tok=access_token, custom_headers=headers
)
# give time for user_daily_visits to update
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
# Notice that the user *still* does not contribute to R30!
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
# advance yet another day with more activity
self.reactor.advance(ONE_DAY_IN_SECONDS)
self.helper.send(
room_id, "Still here!", tok=access_token, custom_headers=headers
)
# give time for user_daily_visits to update
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
# *Now* the user appears in R30.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0}
)
def test_r30v2_returning_dormant_users_not_counted(self):
"""
Tests that dormant users (users inactive for a long time) do not
contribute to R30v2 when they return for just a single day.
This is a key difference between R30 and R30v2.
"""
# set a custom user-agent to impersonate Element/iOS.
headers = (
(
"User-Agent",
"Riot/1.4 (iPhone; iOS 13; Scale/4.00)",
),
)
# Register a user and send a message
user_id = self.register_user("u1", "secret!")
access_token = self.login("u1", "secret!", custom_headers=headers)
room_id = self.helper.create_room_as(
room_creator=user_id, tok=access_token, custom_headers=headers
)
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
# the user goes inactive for 2 months
self.reactor.advance(60 * ONE_DAY_IN_SECONDS)
# the user returns for one day, perhaps just to check out a new feature
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
# Give time for user_daily_visits table to be updated.
# (user_daily_visits is updated every 5 minutes using a looping call.)
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
store = self.hs.get_datastore()
# Check that the user does not contribute to R30v2, even though it's been
# more than 30 days since registration.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
# Check that this is a situation where old R30 differs:
# old R30 DOES count this as 'retained'.
r30_results = self.get_success(store.count_r30_users())
self.assertEqual(r30_results, {"all": 1, "ios": 1})
# Now we want to check that the user will still be able to appear in
# R30v2 as long as the user performs some other activity between
# 30 and 60 days later.
self.reactor.advance(32 * ONE_DAY_IN_SECONDS)
self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
# (give time for tables to update)
self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
# Check the user now satisfies the requirements to appear in R30v2.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
)
# Advance to 59.5 days after the user's first R30v2-eligible activity.
self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS)
# Check the user still appears in R30v2.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
)
# Advance to 60.5 days after the user's first R30v2-eligible activity.
self.reactor.advance(ONE_DAY_IN_SECONDS)
# Check the user no longer appears in R30v2.
r30_results = self.get_success(store.count_r30v2_users())
self.assertEqual(
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)

View file

@ -16,17 +16,19 @@ from typing import Dict
from unittest.mock import Mock
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, StateMap
from synapse.util.frozenutils import unfreeze
from tests import unittest
thread_local = threading.local()
class ThirdPartyRulesTestModule:
class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: ModuleApi):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
return config
def current_rules_module() -> ThirdPartyRulesTestModule:
return thread_local.rules_module
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: ModuleApi):
super().__init__(config, module_api)
def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
return False
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: ModuleApi):
super().__init__(config, module_api)
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
d = event.get_dict()
content = unfreeze(event.content)
content["foo"] = "bar"
d["content"] = content
return d
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def default_config(self):
config = super().default_config()
config["third_party_event_rules"] = {
"module": __name__ + ".ThirdPartyRulesTestModule",
"config": {},
}
return config
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
load_legacy_third_party_event_rules(hs)
return hs
def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
# Some tests might prevent room creation on purpose.
try:
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
except Exception:
pass
def test_third_party_rules(self):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
return ev.type != "foo.bar.forbidden"
return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
callback
]
channel = self.make_request(
"PUT",
@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
return True
return True, None
current_rules_module().check_event_allowed = check
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
self.assertEqual(channel.result["code"], b"500", channel.result)
# check_event_allowed has some error handling, so it shouldn't 500 just because a
# module did something bad.
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "x")
def test_modify_event(self):
"""The module can return a modified version of the event"""
@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
async def check(ev: EventBase, state):
d = ev.get_dict()
d["content"] = {"x": "y"}
return d
return True, d
current_rules_module().check_event_allowed = check
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"msgtype": "m.text",
"body": d["content"]["body"].upper(),
}
return d
return True, d
current_rules_module().check_event_allowed = check
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# Send an event, then edit it.
channel = self.make_request(
@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self):
"""Tests that the module can send an event into a room via the module api"""
"""Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
"body": "Hello!",
@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"sender": self.user_id,
}
event: EventBase = self.get_success(
current_rules_module().module_api.create_and_send_event_into_room(
event_dict
)
self.hs.get_module_api().create_and_send_event_into_room(event_dict)
)
self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id)
self.assertEquals(event.type, "m.room.message")
self.assertEquals(event.content, content)
@unittest.override_config(
{
"third_party_event_rules": {
"module": __name__ + ".LegacyChangeEvents",
"config": {},
}
}
)
def test_legacy_check_event_allowed(self):
"""Tests that the wrapper for legacy check_event_allowed callbacks works
correctly.
"""
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
{
"msgtype": "m.text",
"body": "Original body",
},
access_token=self.tok,
)
self.assertEqual(channel.result["code"], b"200", channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar")
@unittest.override_config(
{
"third_party_event_rules": {
"module": __name__ + ".LegacyDenyNewRooms",
"config": {},
}
}
)
def test_legacy_on_create_room(self):
"""Tests that the wrapper for legacy on_create_room callbacks works
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)

View file
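The test changes above track the new third-party-rules callback shape: `check_event_allowed` now returns an `(allowed, replacement)` tuple rather than the legacy bool-or-dict. A minimal sketch of a new-style callback (illustrative only; `event` is a Synapse `EventBase`):

```python
from typing import Optional, Tuple


async def check_event_allowed(event, state) -> Tuple[bool, Optional[dict]]:
    if event.type == "foo.bar.forbidden":
        return False, None  # reject the event outright
    d = event.get_dict()
    d["content"] = {"x": "y"}  # or allow it, substituting a modified event
    return True, d
```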

@ -19,7 +19,7 @@ import json
import re
import time
import urllib.parse
from typing import Any, Dict, Mapping, MutableMapping, Optional
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import patch
import attr
@ -53,6 +53,9 @@ class RestHelper:
tok: str = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
) -> str:
"""
Create a room.
@ -87,6 +90,7 @@ class RestHelper:
"POST",
path,
json.dumps(content).encode("utf8"),
custom_headers=custom_headers,
)
assert channel.result["code"] == b"%d" % expect_code, channel.result
@ -175,14 +179,30 @@ class RestHelper:
self.auth_user_id = temp_id
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
def send(
self,
room_id,
body=None,
txn_id=None,
tok=None,
expect_code=200,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
if body is None:
body = "body_text_here"
content = {"msgtype": "m.text", "body": body}
return self.send_event(
room_id, "m.room.message", content, txn_id, tok, expect_code
room_id,
"m.room.message",
content,
txn_id,
tok,
expect_code,
custom_headers=custom_headers,
)
def send_event(
@ -193,6 +213,9 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@ -207,6 +230,7 @@ class RestHelper:
"PUT",
path,
json.dumps(content or {}).encode("utf8"),
custom_headers=custom_headers,
)
assert (

View file

@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
"get_account_validity_handler",
"hostname",
]
)

View file

@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
user_id = channel.json_body["user_id"]
return user_id
def login(self, username, password, device_id=None):
def login(
self,
username,
password,
device_id=None,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
body["device_id"] = device_id
channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
"POST",
"/_matrix/client/r0/login",
json.dumps(body).encode("utf8"),
custom_headers=custom_headers,
)
self.assertEqual(channel.code, 200, channel.result)