0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 16:13:50 +01:00

Merge remote-tracking branch 'origin/develop' into rav/drop_py35

This commit is contained in:
Richard van der Hoff 2021-04-08 18:30:38 +01:00
commit 9e167d9c53
67 changed files with 2183 additions and 826 deletions

View file

@ -1,10 +1,27 @@
Synapse 1.31.0rc1 (2021-03-30) Synapse 1.31.0 (2021-04-06)
============================== ===========================
**Note:** As announced in v1.25.0, and in line with the deprecation policy for platform dependencies, this is the last release to support Python 3.5 and PostgreSQL 9.5. Future versions of Synapse will require Python 3.6+ and PostgreSQL 9.6+. **Note:** As announced in v1.25.0, and in line with the deprecation policy for platform dependencies, this is the last release to support Python 3.5 and PostgreSQL 9.5. Future versions of Synapse will require Python 3.6+ and PostgreSQL 9.6+, as per our [deprecation policy](docs/deprecation_policy.md).
This is also the last release that the Synapse team will be publishing packages for Debian Stretch and Ubuntu Xenial. This is also the last release that the Synapse team will be publishing packages for Debian Stretch and Ubuntu Xenial.
Improved Documentation
----------------------
- Add a document describing the deprecation policy for platform dependencies. ([\#9723](https://github.com/matrix-org/synapse/issues/9723))
Internal Changes
----------------
- Revert using `dmypy run` in lint script. ([\#9720](https://github.com/matrix-org/synapse/issues/9720))
- Pin flake8-bugbear's version. ([\#9734](https://github.com/matrix-org/synapse/issues/9734))
Synapse 1.31.0rc1 (2021-03-30)
==============================
Features Features
-------- --------

View file

@ -38,6 +38,7 @@ There are 3 steps to follow under **Installation Instructions**.
- [URL previews](#url-previews) - [URL previews](#url-previews)
- [Troubleshooting Installation](#troubleshooting-installation) - [Troubleshooting Installation](#troubleshooting-installation)
## Choosing your server name ## Choosing your server name
It is important to choose the name for your server before you install Synapse, It is important to choose the name for your server before you install Synapse,

View file

@ -314,6 +314,15 @@ Testing with SyTest is recommended for verifying that changes related to the
Client-Server API are functioning correctly. See the `installation instructions Client-Server API are functioning correctly. See the `installation instructions
<https://github.com/matrix-org/sytest#installing>`_ for details. <https://github.com/matrix-org/sytest#installing>`_ for details.
Platform dependencies
=====================
Synapse uses a number of platform dependencies such as Python and PostgreSQL,
and aims to follow supported upstream versions. See the
`<docs/deprecation_policy.md>`_ document for more details.
Troubleshooting Troubleshooting
=============== ===============
@ -384,7 +393,12 @@ massive excess of outgoing federation requests (see `discussion
indicate that your server is also issuing far more outgoing federation indicate that your server is also issuing far more outgoing federation
requests than can be accounted for by your users' activity, this is a requests than can be accounted for by your users' activity, this is a
likely cause. The misbehavior can be worked around by setting likely cause. The misbehavior can be worked around by setting
``use_presence: false`` in the Synapse config file. the following in the Synapse config file:
.. code-block:: yaml
presence:
enabled: false
People can't accept room invitations from me People can't accept room invitations from me
-------------------------------------------- --------------------------------------------

1
changelog.d/8926.bugfix Normal file
View file

@ -0,0 +1 @@
Prevent `synapse_forward_extremities` and `synapse_excess_extremity_events` Prometheus metrics from initially reporting zero-values after startup.

1
changelog.d/9491.feature Normal file
View file

@ -0,0 +1 @@
Add a Synapse module for routing presence updates between users.

1
changelog.d/9654.feature Normal file
View file

@ -0,0 +1 @@
Include request information in structured logging output.

View file

@ -1 +0,0 @@
Revert using `dmypy run` in lint script.

1
changelog.d/9725.bugfix Normal file
View file

@ -0,0 +1 @@
Fix longstanding bug which caused `duplicate key value violates unique constraint "remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"` errors.

1
changelog.d/9730.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to expiring cache.

View file

@ -1 +0,0 @@
Pin flake8-bugbear's version.

1
changelog.d/9735.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.

1
changelog.d/9736.misc Normal file
View file

@ -0,0 +1 @@
Convert various testcases to `HomeserverTestCase`.

1
changelog.d/9743.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to federation handler and server.

1
changelog.d/9753.misc Normal file
View file

@ -0,0 +1 @@
Check that a `ConfigError` is raised, rather than simply `Exception`, when appropriate in homeserver config file generation tests.

1
changelog.d/9765.docker Normal file
View file

@ -0,0 +1 @@
Move opencontainers labels to the final Docker image such that users can inspect them.

1
changelog.d/9770.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where sharded federation senders could get stuck repeatedly querying the DB in a loop, using lots of CPU.

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.31.0) stable; urgency=medium
* New synapse release 1.31.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 06 Apr 2021 13:08:29 +0100
matrix-synapse-py3 (1.30.1) stable; urgency=medium matrix-synapse-py3 (1.30.1) stable; urgency=medium
* New synapse release 1.30.1. * New synapse release 1.30.1.

View file

@ -18,11 +18,6 @@ ARG PYTHON_VERSION=3.8
### ###
FROM docker.io/python:${PYTHON_VERSION}-slim as builder FROM docker.io/python:${PYTHON_VERSION}-slim as builder
LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
LABEL org.opencontainers.image.licenses='Apache-2.0'
# install the OS build deps # install the OS build deps
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
build-essential \ build-essential \
@ -66,6 +61,11 @@ RUN pip install --prefix="/install" --no-deps --no-warn-script-location /synapse
FROM docker.io/python:${PYTHON_VERSION}-slim FROM docker.io/python:${PYTHON_VERSION}-slim
LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
LABEL org.opencontainers.image.licenses='Apache-2.0'
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
curl \ curl \
gosu \ gosu \

View file

@ -0,0 +1,33 @@
Deprecation Policy for Platform Dependencies
============================================
Synapse has a number of platform dependencies, including Python and PostgreSQL.
This document outlines the policy towards which versions we support, and when we
drop support for versions in the future.
Policy
------
Synapse follows the upstream support life cycles for Python and PostgreSQL,
i.e. when a version reaches End of Life Synapse will withdraw support for that
version in future releases.
Details on the upstream support life cycles for Python and PostgreSQL are
documented at https://endoflife.date/python and
https://endoflife.date/postgresql.
Context
-------
It is important for system admins to have a clear understanding of the platform
requirements of Synapse and its deprecation policies so that they can
effectively plan upgrading their infrastructure ahead of time. This is
especially important in contexts where upgrading the infrastructure requires
auditing and approval from a security team, or where otherwise upgrading is a
long process.
By following the upstream support life cycles Synapse can ensure that its
dependencies continue to get security patches, while not requiring system admins
to constantly update their platform dependencies to the latest versions.

View file

@ -0,0 +1,235 @@
# Presence Router Module
Synapse supports configuring a module that can specify additional users
(local or remote) to should receive certain presence updates from local
users.
Note that routing presence via Application Service transactions is not
currently supported.
The presence routing module is implemented as a Python class, which will
be imported by the running Synapse.
## Python Presence Router Class
The Python class is instantiated with two objects:
* A configuration object of some type (see below).
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods related to presence routing.
Note that one method of `ModuleApi` that may be useful is:
```python
async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None
```
which can be given a list of local or remote MXIDs to broadcast known, online user
presence to (for those users that the receiving user is considered interested in).
It does not include state for users who are currently offline, and it can only be
called on workers that support sending federation.
### Module structure
Below is a list of possible methods that can be implemented, and whether they are
required.
#### `parse_config`
```python
def parse_config(config_dict: dict) -> Any
```
**Required.** A static method that is passed a dictionary of config options, and
should return a validated config object. This method is described further in
[Configuration](#configuration).
#### `get_users_for_states`
```python
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
```
**Required.** An asynchronous method that is passed an iterable of user presence
state. This method can determine whether a given presence update should be sent to certain
users. It does this by returning a dictionary with keys representing local or remote
Matrix User IDs, and values being a python set
of `synapse.handlers.presence.UserPresenceState` instances.
Synapse will then attempt to send the specified presence updates to each user when
possible.
#### `get_interested_users`
```python
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]
```
**Required.** An asynchronous method that is passed a single Matrix User ID. This
method is expected to return the users that the passed in user may be interested in the
presence of. Returned users may be local or remote. The presence routed as a result of
what this method returns is sent in addition to the updates already sent between users
that share a room together. Presence updates are deduplicated.
This method should return a python set of Matrix User IDs, or the object
`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed
user should receive presence information for *all* known users.
For clarity, if the user `@alice:example.org` is passed to this method, and the Set
`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice
should receive presence updates sent by Bob and Charlie, regardless of whether these
users share a room.
### Example
Below is an example implementation of a presence router class.
```python
from typing import Dict, Iterable, Set, Union
from synapse.events.presence_router import PresenceRouter
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
class PresenceRouterConfig:
def __init__(self):
# Config options with their defaults
# A list of users to always send all user presence updates to
self.always_send_to_users = [] # type: List[str]
# A list of users to ignore presence updates for. Does not affect
# shared-room presence relationships
self.blacklisted_users = [] # type: List[str]
class ExamplePresenceRouter:
"""An example implementation of synapse.presence_router.PresenceRouter.
Supports routing all presence to a configured set of users, or a subset
of presence from certain users to members of certain rooms.
Args:
config: A configuration object.
module_api: An instance of Synapse's ModuleApi.
"""
def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi):
self._config = config
self._module_api = module_api
@staticmethod
def parse_config(config_dict: dict) -> PresenceRouterConfig:
"""Parse a configuration dictionary from the homeserver config, do
some validation and return a typed PresenceRouterConfig.
Args:
config_dict: The configuration dictionary.
Returns:
A validated config object.
"""
# Initialise a typed config object
config = PresenceRouterConfig()
always_send_to_users = config_dict.get("always_send_to_users")
blacklisted_users = config_dict.get("blacklisted_users")
# Do some validation of config options... otherwise raise a
# synapse.config.ConfigError.
config.always_send_to_users = always_send_to_users
config.blacklisted_users = blacklisted_users
return config
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
"""Given an iterable of user presence updates, determine where each one
needs to go. Returned results will not affect presence updates that are
sent between users who share a room.
Args:
state_updates: An iterable of user presence state updates.
Returns:
A dictionary of user_id -> set of UserPresenceState that the user should
receive.
"""
destination_users = {} # type: Dict[str, Set[UserPresenceState]
# Ignore any updates for blacklisted users
desired_updates = set()
for update in state_updates:
if update.state_key not in self._config.blacklisted_users:
desired_updates.add(update)
# Send all presence updates to specific users
for user_id in self._config.always_send_to_users:
destination_users[user_id] = desired_updates
return destination_users
async def get_interested_users(
self,
user_id: str,
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
that this user should receive all incoming local and remote presence updates.
Note that this method will only be called for local users.
Args:
user_id: A user requesting presence updates.
Returns:
A set of user IDs to return additional presence updates for, or
PresenceRouter.ALL_USERS to return presence updates for all other users.
"""
if user_id in self._config.always_send_to_users:
return PresenceRouter.ALL_USERS
return set()
```
#### A note on `get_users_for_states` and `get_interested_users`
Both of these methods are effectively two different sides of the same coin. The logic
regarding which users should receive updates for other users should be the same
between them.
`get_users_for_states` is called when presence updates come in from either federation
or local users, and is used to either direct local presence to remote users, or to
wake up the sync streams of local users to collect remote presence.
In contrast, `get_interested_users` is used to determine the users that presence should
be fetched for when a local user is syncing. This presence is then retrieved, before
being fed through `get_users_for_states` once again, with only the syncing user's
routing information pulled from the resulting dictionary.
Their routing logic should thus line up, else you may run into unintended behaviour.
## Configuration
Once you've crafted your module and installed it into the same Python environment as
Synapse, amend your homeserver config file with the following.
```yaml
presence:
routing_module:
module: my_module.ExamplePresenceRouter
config:
# Any configuration options for your module. The below is an example.
# of setting options for ExamplePresenceRouter.
always_send_to_users: ["@presence_gobbler:example.org"]
blacklisted_users:
- "@alice:example.com"
- "@bob:example.com"
...
```
The contents of `config` will be passed as a Python dictionary to the static
`parse_config` method of your class. The object returned by this method will
then be passed to the `__init__` method of your module as `config`.

View file

@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid
# #
#soft_file_limit: 0 #soft_file_limit: 0
# Set to false to disable presence tracking on this homeserver. # Presence tracking allows users to see the state (e.g online/offline)
# of other local and remote users.
# #
#use_presence: false presence:
# Uncomment to disable presence tracking on this homeserver. This option
# replaces the previous top-level 'use_presence' option.
#
#enabled: false
# Presence routers are third-party modules that can specify additional logic
# to where presence updates from users are routed.
#
presence_router:
# The custom module's class. Uncomment to use a custom presence router module.
#
#module: "my_custom_router.PresenceRouter"
# Configuration options of the custom module. Refer to your module's
# documentation for available options.
#
#config:
# example_option: 'something'
# Whether to require authentication to retrieve profile data (avatars, # Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to # display names) of other users through the client API. Defaults to

View file

@ -46,4 +46,4 @@ if [[ -n "$1" ]]; then
fi fi
# Run the tests! # Run the tests!
COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist -count=1 $EXTRA_COMPLEMENT_ARGS ./tests COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist,msc3083 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests

View file

@ -48,7 +48,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.31.0rc1" __version__ = "1.31.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file

@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.hs = hs self.hs = hs
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence self._presence_enabled = hs.config.use_presence
# The number of ongoing syncs on this process, by user id. # The number of ongoing syncs on this process, by user id.
@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing() return _user_syncing()
async def notify_from_replication(self, states, stream_id): async def notify_from_replication(self, states, stream_id):
parties = await get_interested_parties(self.store, states) parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties room_ids_to_states, users_to_states = parties
self.notifier.on_new_event( self.notifier.on_new_event(

View file

@ -27,6 +27,7 @@ import yaml
from netaddr import AddrFormatError, IPNetwork, IPSet from netaddr import AddrFormatError, IPNetwork, IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -238,7 +239,20 @@ class ServerConfig(Config):
self.public_baseurl = config.get("public_baseurl") self.public_baseurl = config.get("public_baseurl")
# Whether to enable user presence. # Whether to enable user presence.
self.use_presence = config.get("use_presence", True) presence_config = config.get("presence") or {}
self.use_presence = presence_config.get("enabled")
if self.use_presence is None:
self.use_presence = config.get("use_presence", True)
# Custom presence router module
self.presence_router_module_class = None
self.presence_router_config = None
presence_router_config = presence_config.get("presence_router")
if presence_router_config:
(
self.presence_router_module_class,
self.presence_router_config,
) = load_module(presence_router_config, ("presence", "presence_router"))
# Whether to update the user directory or not. This should be set to # Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker # false only if we are updating the user directory in a worker
@ -834,9 +848,28 @@ class ServerConfig(Config):
# #
#soft_file_limit: 0 #soft_file_limit: 0
# Set to false to disable presence tracking on this homeserver. # Presence tracking allows users to see the state (e.g online/offline)
# of other local and remote users.
# #
#use_presence: false presence:
# Uncomment to disable presence tracking on this homeserver. This option
# replaces the previous top-level 'use_presence' option.
#
#enabled: false
# Presence routers are third-party modules that can specify additional logic
# to where presence updates from users are routed.
#
presence_router:
# The custom module's class. Uncomment to use a custom presence router module.
#
#module: "my_custom_router.PresenceRouter"
# Configuration options of the custom module. Refer to your module's
# documentation for available options.
#
#config:
# example_option: 'something'
# Whether to require authentication to retrieve profile data (avatars, # Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to # display names) of other users through the client API. Defaults to

View file

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
from synapse.api.presence import UserPresenceState
if TYPE_CHECKING:
from synapse.server import HomeServer
class PresenceRouter:
"""
A module that the homeserver will call upon to help route user presence updates to
additional destinations. If a custom presence router is configured, calls will be
passed to that instead.
"""
ALL_USERS = "ALL"
def __init__(self, hs: "HomeServer"):
self.custom_presence_router = None
# Check whether a custom presence router module has been configured
if hs.config.presence_router_module_class:
# Initialise the module
self.custom_presence_router = hs.config.presence_router_module_class(
config=hs.config.presence_router_config, module_api=hs.get_module_api()
)
# Ensure the module has implemented the required methods
required_methods = ["get_users_for_states", "get_interested_users"]
for method_name in required_methods:
if not hasattr(self.custom_presence_router, method_name):
raise Exception(
"PresenceRouter module '%s' must implement all required methods: %s"
% (
hs.config.presence_router_module_class.__name__,
", ".join(required_methods),
)
)
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
"""
Given an iterable of user presence updates, determine where each one
needs to go.
Args:
state_updates: An iterable of user presence state updates.
Returns:
A dictionary of user_id -> set of UserPresenceState, indicating which
presence updates each user should receive.
"""
if self.custom_presence_router is not None:
# Ask the custom module
return await self.custom_presence_router.get_users_for_states(
state_updates=state_updates
)
# Don't include any extra destinations for presence updates
return {}
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
that this user should receive all incoming local and remote presence updates.
Note that this method will only be called for local users, but can return users
that are local or remote.
Args:
user_id: A user requesting presence updates.
Returns:
A set of user IDs to return presence updates for, or ALL_USERS to return all
known updates.
"""
if self.custom_presence_router is not None:
# Ask the custom module for interested users
return await self.custom_presence_router.get_interested_users(
user_id=user_id
)
# A custom presence router is not defined.
# Don't report any additional interested users
return set()

View file

@ -102,7 +102,7 @@ class FederationClient(FederationBase):
max_len=1000, max_len=1000,
expiry_ms=120 * 1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) ) # type: ExpiringCache[str, EventBase]
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""

View file

@ -739,22 +739,20 @@ class FederationServer(FederationBase):
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self): def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
): ) -> None:
ret = await self.handler.exchange_third_party_invite( await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed sender_user_id, target_user_id, room_id, signed
) )
return ret
async def on_exchange_third_party_invite_request(self, event_dict: Dict): async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
ret = await self.handler.on_exchange_third_party_invite_request(event_dict) await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret
async def check_server_matches_acl(self, server_name: str, room_id: str): async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
"""Check if the given server is allowed by the server ACLs in the room """Check if the given server is allowed by the server ACLs in the room
Args: Args:
@ -878,7 +876,7 @@ class FederationHandlerRegistry:
def register_edu_handler( def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
): ) -> None:
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation EDU of the given type. federation EDU of the given type.
@ -897,7 +895,7 @@ class FederationHandlerRegistry:
def register_query_handler( def register_query_handler(
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
): ) -> None:
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation query of the given type. federation query of the given type.
@ -915,15 +913,17 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str): def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
"""Register that the EDU handler is on a different instance than master.""" """Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name] self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
) -> None:
"""Register that the EDU handler is on multiple instances.""" """Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict): async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence: if not self.config.use_presence and edu_type == EduTypes.Presence:
return return

View file

@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.events.presence_router import PresenceRouter
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self._presence_router = None # type: Optional[PresenceRouter]
self._transaction_manager = TransactionManager(hs) self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender):
"""Given a list of states populate self.pending_presence_by_dest and """Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination poke to send a new transaction to each destination
""" """
hosts_and_states = await get_interested_remotes(self.store, states, self.state) # We pull the presence router here instead of __init__
# to prevent a dependency cycle:
#
# AuthHandler -> Notifier -> FederationSender
# -> PresenceRouter -> ModuleApi -> AuthHandler
if self._presence_router is None:
self._presence_router = self.hs.get_presence_router()
assert self._presence_router is not None
hosts_and_states = await get_interested_remotes(
self.store,
self._presence_router,
states,
self.state,
)
for destinations, states in hosts_and_states: for destinations, states in hosts_and_states:
for destination in destinations: for destination in destinations:
@ -717,16 +734,18 @@ class FederationSender(AbstractFederationSender):
self._catchup_after_startup_timer = None self._catchup_after_startup_timer = None
break break
last_processed = destinations_to_wake[-1]
destinations_to_wake = [ destinations_to_wake = [
d d
for d in destinations_to_wake for d in destinations_to_wake
if self._federation_shard_config.should_handle(self._instance_name, d) if self._federation_shard_config.should_handle(self._instance_name, d)
] ]
for last_processed in destinations_to_wake: for destination in destinations_to_wake:
logger.info( logger.info(
"Destination %s has outstanding catch-up, waking up.", "Destination %s has outstanding catch-up, waking up.",
last_processed, last_processed,
) )
self.wake_destination(last_processed) self.wake_destination(destination)
await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC) await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)

View file

@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id): async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(content) await self.handler.on_exchange_third_party_invite_request(content)
return 200, content return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):

View file

@ -631,7 +631,7 @@ class DeviceListUpdater:
max_len=10000, max_len=10000,
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
iterable=True, iterable=True,
) ) # type: ExpiringCache[str, Set[str]]
# Attempt to resync out of sync device lists every 30s. # Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False self._resync_retry_in_progress = False
@ -760,7 +760,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full """Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta. resync, or whether we have enough data that we can just apply the delta.
""" """
seen_updates = self._seen_updates.get(user_id, set()) seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

View file

@ -38,7 +38,6 @@ from synapse.types import (
) )
from synapse.util import json_decoder, unwrapFirstError from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
# user_id -> list of updates waiting to be handled. # user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
async def incoming_signing_key_update( async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict self, origin: str, edu_content: JsonDict
) -> None: ) -> None:

View file

@ -21,7 +21,17 @@ import itertools
import logging import logging
from collections.abc import Container from collections.abc import Container
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: async def on_receive_pdu(
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
) -> None:
"""Process a PDU received via a federation /send/ transaction, or """Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events via backfill of missing prev_events
Args: Args:
origin (str): server which initiated the /send/ transaction. Will origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state. be used to fetch missing events or state.
pdu (FrozenEvent): received PDU pdu: received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event. we pulled it as the result of a missing prev_event.
""" """
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state) await self._process_received_pdu(origin, pdu, state=state)
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
) -> None:
""" """
Args: Args:
origin (str): Origin of the pdu. Will be called to get the missing events origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu pdu: received pdu
prevs (set(str)): List of event ids which we are missing prevs: List of event ids which we are missing
min_depth (int): Minimum depth of events to return. min_depth: Minimum depth of events to return.
""" """
room_id = pdu.room_id room_id = pdu.room_id
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state: Optional[Iterable[EventBase]],
): ) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it """Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler. through the StateHandler.
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender) logger.exception("Failed to resync device for %s", sender)
@log_function @log_function
async def backfill(self, dest, room_id, limit, extremities): async def backfill(
self, dest: str, room_id: str, limit: int, extremities: List[str]
) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id` """Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side This will attempt to get more events from the remote. If the other side
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id) curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state): def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state """Get joined domains from state
Args: Args:
state (dict[tuple, FrozenEvent]): State map from type/state state: State map from type/state key to event.
key to event.
Returns: Returns:
list[tuple[str, int]]: Returns a list of servers with the Returns a list of servers with the lowest depth of their joins.
lowest depth of their joins. Sorted by lowest depth first. Sorted by lowest depth first.
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name domain for domain, depth in curr_domains if domain != self.server_name
] ]
async def try_backfill(domains): async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time? # TODO: Should we try multiple of these at a time?
for dom in domains: for dom in domains:
try: try:
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill( success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains] [
dom
for dom, _ in likely_extremeties_domains
if dom not in tried_domains
]
) )
if success: if success:
return True return True
tried_domains.update(dom for dom, _ in likely_domains) tried_domains.update(dom for dom, _ in likely_extremeties_domains)
return False return False
async def _get_events_and_persist( async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str] self, destination: str, room_id: str, events: Iterable[str]
): ) -> None:
"""Fetch the given events from a server, and persist them as outliers. """Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the This function *does not* recursively get missing auth events of the
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos, event_infos,
) )
def _sanity_check_event(self, ev): def _sanity_check_event(self, ev: EventBase) -> None:
""" """
Do some early sanity checks of a received event Do some early sanity checks of a received event
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches. or cascade of event fetches.
Args: Args:
ev (synapse.events.EventBase): event to be checked ev: event to be checked
Returns: None
Raises: Raises:
SynapseError if the event does not pass muster SynapseError if the event does not pass muster
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
) )
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event): async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing. """Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution. Invites must be signed by the invitee's server before distribution.
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue) run_in_background(self._handle_queued_pdus, room_queue)
async def _handle_queued_pdus(self, room_queue): async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
"""Process PDUs which got queued up while we were busy send_joining. """Process PDUs which got queued up while we were busy send_joining.
Args: Args:
room_queue (list[FrozenEvent, str]): list of PDUs to be processed room_queue: list of PDUs to be processed and the servers that sent them
and the servers that sent them
""" """
for p, origin in room_queue: for p, origin in room_queue:
try: try:
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
return event return event
async def on_send_join_request(self, origin, pdu): async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and """We have received a join event for a room. Fully process it and
respond with the current state and auth chains. respond with the current state and auth chains.
""" """
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request( async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion self, origin: str, event: EventBase, room_version: RoomVersion
): ) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it. """We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event. Respond with the now signed event.
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
return event return event
async def on_send_leave_request(self, origin, pdu): async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it.""" """ We have received a leave event for a room. Fully process it."""
event = pdu event = pdu
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
else: else:
return None return None
async def get_min_depth_for_context(self, context): async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context) return await self.store.get_min_depth(context)
async def _handle_new_event( async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False self,
): origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event( context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled origin, event, state=state, auth_events=auth_events, backfilled=backfilled
) )
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e) logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
):
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
await self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
return ret
async def on_get_missing_events( async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit self,
): origin: str,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin) in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth assumes that we have already processed all events in remote_auth
Params: Params:
local_auth (list) local_auth
remote_auth (list) remote_auth
Returns: Returns:
dict dict
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
@log_function @log_function
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
): ) -> None:
third_party_invite = {"signed": signed} third_party_invite = {"signed": signed}
event_dict = { event_dict = {
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context) await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite( async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context self,
): room_version: str,
event_dict: JsonDict,
event: EventBase,
context: EventContext,
) -> Tuple[EventBase, EventContext]:
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"], event.content["third_party_invite"]["signed"]["token"],
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config) EventValidator().validate_new(event, self.config)
return (event, context) return (event, context)
async def _check_signature(self, event, context): async def _check_signature(self, event: EventBase, context: EventContext) -> None:
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
Args: Args:
event (Event): The m.room.member event to check event: The m.room.member event to check
context (EventContext): context:
Raises: Raises:
AuthError: if signature didn't match any keys, or key has been AuthError: if signature didn't match any keys, or key has been
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
raise last_exception raise last_exception
async def _check_key_revocation(self, public_key, url): async def _check_key_revocation(self, public_key: str, url: str) -> None:
""" """
Checks whether public_key has been revoked. Checks whether public_key has been revoked.
Args: Args:
public_key (str): base-64 encoded public key. public_key: base-64 encoded public key.
url (str): Key revocation URL. url: Key revocation URL.
Raises: Raises:
AuthError: if they key has been revoked. AuthError: if they key has been revoked.

View file

@ -25,7 +25,17 @@ The methods that define policy are:
import abc import abc
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple from typing import (
TYPE_CHECKING,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import ContextManager from typing_extensions import ContextManager
@ -34,6 +44,7 @@ import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
@ -42,7 +53,7 @@ from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.presence_router = hs.get_presence_router()
self._presence_enabled = hs.config.use_presence self._presence_enabled = hs.config.use_presence
federation_registry = hs.get_federation_registry() federation_registry = hs.get_federation_registry()
@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
""" """
stream_id, max_token = await self.store.update_presence(states) stream_id, max_token = await self.store.update_presence(states)
parties = await get_interested_parties(self.store, states) parties = await get_interested_parties(self.store, self.presence_router, states)
room_ids_to_states, users_to_states = parties room_ids_to_states, users_to_states = parties
self.notifier.on_new_event( self.notifier.on_new_event(
@ -1041,7 +1053,12 @@ class PresenceEventSource:
# #
# Presence -> Notifier -> PresenceEventSource -> Presence # Presence -> Notifier -> PresenceEventSource -> Presence
# #
# Same with get_module_api, get_presence_router
#
# AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
self.get_presence_handler = hs.get_presence_handler self.get_presence_handler = hs.get_presence_handler
self.get_module_api = hs.get_module_api
self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -1055,7 +1072,7 @@ class PresenceEventSource:
include_offline=True, include_offline=True,
explicit_room_id=None, explicit_room_id=None,
**kwargs **kwargs
): ) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are: # The process for getting presence events are:
# 1. Get the rooms the user is in. # 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms. # 2. Get the list of user in the rooms.
@ -1068,7 +1085,17 @@ class PresenceEventSource:
# We don't try and limit the presence updates by the current token, as # We don't try and limit the presence updates by the current token, as
# sending down the rare duplicate is not a concern. # sending down the rare duplicate is not a concern.
user_id = user.to_string()
stream_change_cache = self.store.presence_stream_cache
with Measure(self.clock, "presence.get_new_events"): with Measure(self.clock, "presence.get_new_events"):
if user_id in self.get_module_api()._send_full_presence_to_local_users:
# This user has been specified by a module to receive all current, online
# user presence. Removing from_key and setting include_offline to false
# will do effectively this.
from_key = None
include_offline = False
if from_key is not None: if from_key is not None:
from_key = int(from_key) from_key = int(from_key)
@ -1091,59 +1118,209 @@ class PresenceEventSource:
# doesn't return. C.f. #5503. # doesn't return. C.f. #5503.
return [], max_token return [], max_token
presence = self.get_presence_handler() # Figure out which other users this user should receive updates for
stream_change_cache = self.store.presence_stream_cache
users_interested_in = await self._get_interested_in(user, explicit_room_id) users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set() # type: Collection[str] # We have a set of users that we're interested in the presence of. We want to
changed = None # cross-reference that with the users that have actually changed their presence.
if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key)
if changed is not None and len(changed) < 500: # Check whether this user should see all user updates
assert isinstance(user_ids_changed, set)
# For small deltas, its quicker to get all changes and then if users_interested_in == PresenceRouter.ALL_USERS:
# work out if we share a room or they're in our presence list # Provide presence state for all users
get_updates_counter.labels("stream").inc() presence_updates = await self._filter_all_presence_updates_for_user(
for other_user_id in changed: user_id, include_offline, from_key
if other_user_id in users_interested_in: )
user_ids_changed.add(other_user_id)
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.labels("full").inc()
if from_key: # Remove the user from the list of users to receive all presence
user_ids_changed = stream_change_cache.get_entities_changed( if user_id in self.get_module_api()._send_full_presence_to_local_users:
users_interested_in, from_key self.get_module_api()._send_full_presence_to_local_users.remove(
user_id
) )
return presence_updates, max_token
# Make mypy happy. users_interested_in should now be a set
assert not isinstance(users_interested_in, str)
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users = (
set()
) # type: Union[Set[str], FrozenSet[str]]
if from_key:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
# Use a slightly-optimised method for processing smaller sets of updates.
if updated_users is not None and len(updated_users) < 500:
# For small deltas, it's quicker to get all changes and then
# cross-reference with the users we're interested in
get_updates_counter.labels("stream").inc()
for other_user_id in updated_users:
if other_user_id in users_interested_in:
# mypy thinks this variable could be a FrozenSet as it's possibly set
# to one in the `get_entities_changed` call below, and `add()` is not
# method on a FrozenSet. That doesn't affect us here though, as
# `interested_and_updated_users` is clearly a set() above.
interested_and_updated_users.add(other_user_id) # type: ignore
else: else:
user_ids_changed = users_interested_in # Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.labels("full").inc()
updates = await presence.current_state_for_users(user_ids_changed) interested_and_updated_users = (
stream_change_cache.get_entities_changed(
users_interested_in, from_key
)
)
else:
# No from_key has been specified. Return the presence for all users
# this user is interested in
interested_and_updated_users = users_interested_in
if include_offline: # Retrieve the current presence state for each user
return (list(updates.values()), max_token) users_to_state = await self.get_presence_handler().current_state_for_users(
else: interested_and_updated_users
return (
[s for s in updates.values() if s.state != PresenceState.OFFLINE],
max_token,
) )
presence_updates = list(users_to_state.values())
# Remove the user from the list of users to receive all presence
if user_id in self.get_module_api()._send_full_presence_to_local_users:
self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
if not include_offline:
# Filter out offline presence states
presence_updates = self._filter_offline_presence_state(presence_updates)
return presence_updates, max_token
async def _filter_all_presence_updates_for_user(
self,
user_id: str,
include_offline: bool,
from_key: Optional[int] = None,
) -> List[UserPresenceState]:
"""
Computes the presence updates a user should receive.
First pulls presence updates from the database. Then consults PresenceRouter
for whether any updates should be excluded by user ID.
Args:
user_id: The User ID of the user to compute presence updates for.
include_offline: Whether to include offline presence states from the results.
from_key: The minimum stream ID of updates to pull from the database
before filtering.
Returns:
A list of presence states for the given user to receive.
"""
if from_key:
# Only return updates since the last sync
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
from_key
)
if not updated_users:
updated_users = []
# Get the actual presence update for each change
users_to_state = await self.get_presence_handler().current_state_for_users(
updated_users
)
presence_updates = list(users_to_state.values())
if not include_offline:
# Filter out offline states
presence_updates = self._filter_offline_presence_state(presence_updates)
else:
users_to_state = await self.store.get_presence_for_all_users(
include_offline=include_offline
)
presence_updates = list(users_to_state.values())
# TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
# module for information on a number of users when we then only take the info
# for a single user
# Filter through the presence router
users_to_state_set = await self.get_presence_router().get_users_for_states(
presence_updates
)
# We only want the mapping for the syncing user
presence_updates = list(users_to_state_set[user_id])
# Return presence information for all users
return presence_updates
def _filter_offline_presence_state(
self, presence_updates: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
"""Given an iterable containing user presence updates, return a list with any offline
presence states removed.
Args:
presence_updates: Presence states to filter
Returns:
A new list with any offline presence states removed.
"""
return [
update
for update in presence_updates
if update.state != PresenceState.OFFLINE
]
def get_current_key(self): def get_current_key(self):
return self.store.get_current_presence_token() return self.store.get_current_presence_token()
@cached(num_args=2, cache_context=True) @cached(num_args=2, cache_context=True)
async def _get_interested_in(self, user, explicit_room_id, cache_context): async def _get_interested_in(
self,
user: UserID,
explicit_room_id: Optional[str] = None,
cache_context: Optional[_CacheContext] = None,
) -> Union[Set[str], str]:
"""Returns the set of users that the given user should see presence """Returns the set of users that the given user should see presence
updates for updates for.
Args:
user: The user to retrieve presence updates for.
explicit_room_id: The users that are in the room will be returned.
Returns:
A set of user IDs to return presence updates for, or "ALL" to return all
known updates.
""" """
user_id = user.to_string() user_id = user.to_string()
users_interested_in = set() users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence users_interested_in.add(user_id) # So that we receive our own presence
# cache_context isn't likely to ever be None due to the @cached decorator,
# but we can't have a non-optional argument after the optional argument
# explicit_room_id either. Assert cache_context is not None so we can use it
# without mypy complaining.
assert cache_context
# Check with the presence router whether we should poll additional users for
# their presence information
additional_users = await self.get_presence_router().get_interested_users(
user.to_string()
)
if additional_users == PresenceRouter.ALL_USERS:
# If the module requested that this user see the presence updates of *all*
# users, then simply return that instead of calculating what rooms this
# user shares
return PresenceRouter.ALL_USERS
# Add the additional users from the router
users_interested_in.update(additional_users)
# Find the users who share a room with this user
users_who_share_room = await self.store.get_users_who_share_room_with_user( users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate user_id, on_invalidate=cache_context.invalidate
) )
@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
async def get_interested_parties( async def get_interested_parties(
store: DataStore, states: List[UserPresenceState] store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
"""Given a list of states return which entities (rooms, users) """Given a list of states return which entities (rooms, users)
are interested in the given states. are interested in the given states.
Args: Args:
store store: The homeserver's data store.
states presence_router: A module for augmenting the destinations for presence updates.
states: A list of incoming user presence updates.
Returns: Returns:
A 2-tuple of `(room_ids_to_states, users_to_states)`, A 2-tuple of `(room_ids_to_states, users_to_states)`,
@ -1337,11 +1515,22 @@ async def get_interested_parties(
# Always notify self # Always notify self
users_to_states.setdefault(state.user_id, []).append(state) users_to_states.setdefault(state.user_id, []).append(state)
# Ask a presence routing module for any additional parties if one
# is loaded.
router_users_to_states = await presence_router.get_users_for_states(states)
# Update the dictionaries with additional destinations and state to send
for user_id, user_states in router_users_to_states.items():
users_to_states.setdefault(user_id, []).extend(user_states)
return room_ids_to_states, users_to_states return room_ids_to_states, users_to_states
async def get_interested_remotes( async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler store: DataStore,
presence_router: PresenceRouter,
states: List[UserPresenceState],
state_handler: StateHandler,
) -> List[Tuple[Collection[str], List[UserPresenceState]]]: ) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers """Given a list of presence states figure out which remote servers
should be sent which. should be sent which.
@ -1349,9 +1538,10 @@ async def get_interested_remotes(
All the presence states should be for local users only. All the presence states should be for local users only.
Args: Args:
store store: The homeserver's data store.
states presence_router: A module for augmenting the destinations for presence updates.
state_handler states: A list of incoming user presence updates.
state_handler:
Returns: Returns:
A list of 2-tuples of destinations and states, where for A list of 2-tuples of destinations and states, where for
@ -1363,7 +1553,9 @@ async def get_interested_remotes(
# First we look up the rooms each user is in (as well as any explicit # First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote # subscriptions), then for each distinct room we look up the remote
# hosts in those rooms. # hosts in those rooms.
room_ids_to_states, users_to_states = await get_interested_parties(store, states) room_ids_to_states, users_to_states = await get_interested_parties(
store, presence_router, states
)
for room_id, states in room_ids_to_states.items(): for room_id, states in room_ids_to_states.items():
hosts = await state_handler.get_current_hosts_in_room(room_id) hosts = await state_handler.get_current_hosts_in_room(room_id)

View file

@ -20,7 +20,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types from synapse import types
from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@ -178,6 +179,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
async def _can_join_without_invite(
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
) -> bool:
"""
Check whether a user can join a room without an invite.
When joining a room with restricted joined rules (as defined in MSC3083),
the membership of spaces must be checked during join.
Args:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
Returns:
True if the user can join the room, false otherwise.
"""
# This only applies to room versions which support the new join rule.
if not room_version.msc3083_join_rules:
return True
# If there's no join rule, then it defaults to public (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return True
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self.store.get_event(join_rules_event_id)
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
return True
# If allowed is of the wrong form, then only allow invited users.
allowed_spaces = join_rules_event.content.get("allow", [])
if not isinstance(allowed_spaces, list):
return False
# Get the list of joined rooms and see if there's an overlap.
joined_rooms = await self.store.get_rooms_for_user(user_id)
# Pull out the other room IDs, invalid data gets filtered.
for space in allowed_spaces:
if not isinstance(space, dict):
continue
space_id = space.get("space")
if not isinstance(space_id, str):
continue
# The user was joined to one of the spaces specified, they can join
# this room!
if space_id in joined_rooms:
return True
# The user was not in any of the required spaces.
return False
async def _local_membership_update( async def _local_membership_update(
self, self,
requester: Requester, requester: Requester,
@ -235,9 +292,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
newly_joined = True newly_joined = True
user_is_invited = False
if prev_member_event_id: if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
user_is_invited = prev_member_event.membership == Membership.INVITE
# If the member is not already in the room and is not accepting an invite,
# check if they should be allowed access via membership in a space.
if (
newly_joined
and not user_is_invited
and not await self._can_join_without_invite(
prev_state_ids, event.room_version, user_id
)
):
raise AuthError(
403,
"You do not belong to any of the required spaces to join this room.",
)
# Only rate-limit if the user actually joined the room, otherwise we'll end # Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates. # up blocking profile updates.

View file

@ -252,13 +252,13 @@ class SyncHandler:
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id) # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache( self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache", "lazy_loaded_members_cache",
self.clock, self.clock,
max_len=0, max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
) ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
async def wait_for_sync_for_user( async def wait_for_sync_for_user(
self, self,
@ -733,8 +733,10 @@ class SyncHandler:
def get_lazy_loaded_members_cache( def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]] self, cache_key: Tuple[str, Optional[str]]
) -> LruCache: ) -> LruCache[str, str]:
cache = self.lazy_loaded_members_cache.get(cache_key) cache = self.lazy_loaded_members_cache.get(
cache_key
) # type: Optional[LruCache[str, str]]
if cache is None: if cache is None:
logger.debug("creating LruCache for %r", cache_key) logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)

View file

@ -14,7 +14,7 @@
import contextlib import contextlib
import logging import logging
import time import time
from typing import Optional, Type, Union from typing import Optional, Tuple, Type, Union
import attr import attr
from zope.interface import implementer from zope.interface import implementer
@ -26,7 +26,11 @@ from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig from synapse.config.server import ListenerConfig
from synapse.http import get_request_user_agent, redact_uri from synapse.http import get_request_user_agent, redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import (
ContextRequest,
LoggingContext,
PreserveLoggingContext,
)
from synapse.types import Requester from synapse.types import Requester
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,7 +67,7 @@ class SynapseRequest(Request):
# The requester, if authenticated. For federation requests this is the # The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object. # server name, for client requests this is the Requester object.
self.requester = None # type: Optional[Union[Requester, str]] self._requester = None # type: Optional[Union[Requester, str]]
# we can't yet create the logcontext, as we don't know the method. # we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext] self.logcontext = None # type: Optional[LoggingContext]
@ -93,6 +97,31 @@ class SynapseRequest(Request):
self.site.site_tag, self.site.site_tag,
) )
@property
def requester(self) -> Optional[Union[Requester, str]]:
return self._requester
@requester.setter
def requester(self, value: Union[Requester, str]) -> None:
# Store the requester, and update some properties based on it.
# This should only be called once.
assert self._requester is None
self._requester = value
# A logging context should exist by now (and have a ContextRequest).
assert self.logcontext is not None
assert self.logcontext.request is not None
(
requester,
authenticated_entity,
) = self.get_authenticated_entity()
self.logcontext.request.requester = requester
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self): def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq) return "%s-%i" % (self.get_method(), self.request_seq)
@ -126,13 +155,60 @@ class SynapseRequest(Request):
return self.method.decode("ascii") return self.method.decode("ascii")
return method return method
def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
"""
Get the "authenticated" entity of the request, which might be the user
performing the action, or a user being puppeted by a server admin.
Returns:
A tuple:
The first item is a string representing the user making the request.
The second item is a string or None representing the user who
authenticated when making this request. See
Requester.authenticated_entity.
"""
# Convert the requester into a string that we can log
if isinstance(self._requester, str):
return self._requester, None
elif isinstance(self._requester, Requester):
requester = self._requester.user.to_string()
authenticated_entity = self._requester.authenticated_entity
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we return both.
if self._requester.user.to_string() != authenticated_entity:
return requester, authenticated_entity
return requester, None
elif self._requester is not None:
# This shouldn't happen, but we log it so we don't lose information
# and can see that we're doing something wrong.
return repr(self._requester), None # type: ignore[unreachable]
return None, None
def render(self, resrc): def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our # this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource. # case the Resource in question will normally be a JsonResource.
# create a LogContext for this request # create a LogContext for this request
request_id = self.get_request_id() request_id = self.get_request_id()
self.logcontext = LoggingContext(request_id, request=request_id) self.logcontext = LoggingContext(
request_id,
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
site_tag=self.site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
authenticated_entity=None,
method=self.get_method(),
url=self.get_redacted_uri(),
protocol=self.clientproto.decode("ascii", errors="replace"),
user_agent=get_request_user_agent(self),
),
)
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.site.server_version_string)
@ -277,25 +353,6 @@ class SynapseRequest(Request):
# to the client (nb may be negative) # to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time response_send_time = self.finish_time - self._processing_finished_time
# Convert the requester into a string that we can log
authenticated_entity = None
if isinstance(self.requester, str):
authenticated_entity = self.requester
elif isinstance(self.requester, Requester):
authenticated_entity = self.requester.authenticated_entity
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
authenticated_entity,
self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
user_agent = get_request_user_agent(self, "-") user_agent = get_request_user_agent(self, "-")
code = str(self.code) code = str(self.code)
@ -305,6 +362,13 @@ class SynapseRequest(Request):
code += "!" code += "!"
log_level = logging.INFO if self._should_log_request() else logging.DEBUG log_level = logging.INFO if self._should_log_request() else logging.DEBUG
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we log both.
requester, authenticated_entity = self.get_authenticated_entity()
if authenticated_entity:
requester = "{}.{}".format(authenticated_entity, requester)
self.site.access_logger.log( self.site.access_logger.log(
log_level, log_level,
"%s - %s - {%s}" "%s - %s - {%s}"
@ -312,7 +376,7 @@ class SynapseRequest(Request):
' %sB %s "%s %s %s" "%s" [%d dbevts]', ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
authenticated_entity, requester,
processing_time, processing_time,
response_send_time, response_send_time,
usage.ru_utime, usage.ru_utime,

View file

@ -22,7 +22,6 @@ them.
See doc/log_contexts.rst for details on how this works. See doc/log_contexts.rst for details on how this works.
""" """
import inspect import inspect
import logging import logging
import threading import threading
@ -30,6 +29,7 @@ import types
import warnings import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.internet import defer, threads from twisted.internet import defer, threads
@ -181,6 +181,29 @@ class ContextResourceUsage:
return res return res
@attr.s(slots=True)
class ContextRequest:
"""
A bundle of attributes from the SynapseRequest object.
This exists to:
* Avoid a cycle between LoggingContext and SynapseRequest.
* Be a single variable that can be passed from parent LoggingContexts to
their children.
"""
request_id = attr.ib(type=str)
ip_address = attr.ib(type=str)
site_tag = attr.ib(type=str)
requester = attr.ib(type=Optional[str])
authenticated_entity = attr.ib(type=Optional[str])
method = attr.ib(type=str)
url = attr.ib(type=str)
protocol = attr.ib(type=str)
user_agent = attr.ib(type=str)
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
@ -256,7 +279,7 @@ class LoggingContext:
self, self,
name: Optional[str] = None, name: Optional[str] = None,
parent_context: "Optional[LoggingContext]" = None, parent_context: "Optional[LoggingContext]" = None,
request: Optional[str] = None, request: Optional[ContextRequest] = None,
) -> None: ) -> None:
self.previous_context = current_context() self.previous_context = current_context()
self.name = name self.name = name
@ -281,7 +304,11 @@ class LoggingContext:
self.parent_context = parent_context self.parent_context = parent_context
if self.parent_context is not None: if self.parent_context is not None:
self.parent_context.copy_to(self) # we track the current request_id
self.request = self.parent_context.request
# we also track the current scope:
self.scope = self.parent_context.scope
if request is not None: if request is not None:
# the request param overrides the request from the parent context # the request param overrides the request from the parent context
@ -289,7 +316,7 @@ class LoggingContext:
def __str__(self) -> str: def __str__(self) -> str:
if self.request: if self.request:
return str(self.request) return self.request.request_id
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@classmethod @classmethod
@ -556,8 +583,23 @@ class LoggingContextFilter(logging.Filter):
# we end up in a death spiral of infinite loops, so let's check, for # we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake. # robustness' sake.
if context is not None: if context is not None:
# Logging is interested in the request. # Logging is interested in the request ID. Note that for backwards
record.request = context.request # type: ignore # compatibility this is stored as the "request" on the record.
record.request = str(context) # type: ignore
# Add some data from the HTTP request.
request = context.request
if request is None:
return True
record.ip_address = request.ip_address # type: ignore
record.site_tag = request.site_tag # type: ignore
record.requester = request.requester # type: ignore
record.authenticated_entity = request.authenticated_entity # type: ignore
record.method = request.method # type: ignore
record.url = request.url # type: ignore
record.protocol = request.protocol # type: ignore
record.user_agent = request.user_agent # type: ignore
return True return True
@ -630,8 +672,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
def nested_logging_context(suffix: str) -> LoggingContext: def nested_logging_context(suffix: str) -> LoggingContext:
"""Creates a new logging context as a child of another. """Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's The nested logging context will have a 'name' made up of the parent context's
request, plus the given suffix. name, plus the given suffix.
CPU/db usage stats will be added to the parent context's on exit. CPU/db usage stats will be added to the parent context's on exit.
@ -641,7 +683,7 @@ def nested_logging_context(suffix: str) -> LoggingContext:
# ... do stuff # ... do stuff
Args: Args:
suffix: suffix to add to the parent context's 'request'. suffix: suffix to add to the parent context's 'name'.
Returns: Returns:
LoggingContext: new logging context. LoggingContext: new logging context.
@ -653,11 +695,17 @@ def nested_logging_context(suffix: str) -> LoggingContext:
) )
parent_context = None parent_context = None
prefix = "" prefix = ""
request = None
else: else:
assert isinstance(curr_context, LoggingContext) assert isinstance(curr_context, LoggingContext)
parent_context = curr_context parent_context = curr_context
prefix = str(parent_context.request) prefix = str(parent_context.name)
return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix) request = parent_context.request
return LoggingContext(
prefix + "-" + suffix,
parent_context=parent_context,
request=request,
)
def preserve_fn(f): def preserve_fn(f):

View file

@ -214,7 +214,12 @@ class GaugeBucketCollector:
Prometheus, and optimise for that case. Prometheus, and optimise for that case.
""" """
__slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric") __slots__ = (
"_name",
"_documentation",
"_bucket_bounds",
"_metric",
)
def __init__( def __init__(
self, self,
@ -242,11 +247,16 @@ class GaugeBucketCollector:
if self._bucket_bounds[-1] != float("inf"): if self._bucket_bounds[-1] != float("inf"):
self._bucket_bounds.append(float("inf")) self._bucket_bounds.append(float("inf"))
self._metric = self._values_to_metric([]) # We initially set this to None. We won't report metrics until
# this has been initialised after a successful data update
self._metric = None # type: Optional[GaugeHistogramMetricFamily]
registry.register(self) registry.register(self)
def collect(self): def collect(self):
yield self._metric # Don't report metrics unless we've already collected some data
if self._metric is not None:
yield self._metric
def update_data(self, values: Iterable[float]): def update_data(self, values: Iterable[float]):
"""Update the data to be reported by the metric """Update the data to be reported by the metric

View file

@ -16,7 +16,7 @@
import logging import logging
import threading import threading
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Set from typing import TYPE_CHECKING, Dict, Optional, Set, Union
from prometheus_client.core import REGISTRY, Counter, Gauge from prometheus_client.core import REGISTRY, Counter, Gauge
@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc() _background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc()
with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: with BackgroundProcessLoggingContext(desc, count) as context:
try: try:
ctx = noop_context_manager() ctx = noop_context_manager()
if bg_start_span: if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request}) ctx = start_active_span(desc, tags={"request_id": str(context)})
with ctx: with ctx:
return await maybe_awaitable(func(*args, **kwargs)) return await maybe_awaitable(func(*args, **kwargs))
except Exception: except Exception:
@ -242,13 +242,19 @@ class BackgroundProcessLoggingContext(LoggingContext):
processes. processes.
""" """
__slots__ = ["_proc"] __slots__ = ["_id", "_proc"]
def __init__(self, name: str, request: Optional[str] = None): def __init__(self, name: str, id: Optional[Union[int, str]] = None):
super().__init__(name, request=request) super().__init__(name)
self._id = id
self._proc = _BackgroundProcess(name, self) self._proc = _BackgroundProcess(name, self)
def __str__(self) -> str:
if self._id is not None:
return "%s-%s" % (self.name, self._id)
return "%s@%x" % (self.name, id(self))
def start(self, rusage: "Optional[resource._RUsage]"): def start(self, rusage: "Optional[resource._RUsage]"):
"""Log context has started running (again).""" """Log context has started running (again)."""

View file

@ -50,11 +50,20 @@ class ModuleApi:
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = auth_handler self._auth_handler = auth_handler
self._server_name = hs.hostname self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"]
# We expose these as properties below in order to attach a helpful docstring. # We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
self._public_room_list_manager = PublicRoomListManager(hs) self._public_room_list_manager = PublicRoomListManager(hs)
# The next time these users sync, they will receive the current presence
# state of all local users. Users are added by send_local_online_presence_to,
# and removed after a successful sync.
#
# We make this a private variable to deter modules from accessing it directly,
# though other classes in Synapse will still do so.
self._send_full_presence_to_local_users = set()
@property @property
def http_client(self): def http_client(self):
"""Allows making outbound HTTP requests to remote resources. """Allows making outbound HTTP requests to remote resources.
@ -385,6 +394,47 @@ class ModuleApi:
return event return event
async def send_local_online_presence_to(self, users: Iterable[str]) -> None:
"""
Forces the equivalent of a presence initial_sync for a set of local or remote
users. The users will receive presence for all currently online users that they
are considered interested in.
Updates to remote users will be sent immediately, whereas local users will receive
them on their next sync attempt.
Note that this method can only be run on the main or federation_sender worker
processes.
"""
if not self._hs.should_send_federation():
raise Exception(
"send_local_online_presence_to can only be run "
"on processes that send federation",
)
for user in users:
if self._hs.is_mine_id(user):
# Modify SyncHandler._generate_sync_entry_for_presence to call
# presence_source.get_new_events with an empty `from_key` if
# that user's ID were in a list modified by ModuleApi somewhere.
# That user would then get all presence state on next incremental sync.
# Force a presence initial_sync for this user next time
self._send_full_presence_to_local_users.add(user)
else:
# Retrieve presence state for currently online users that this user
# is considered interested in
presence_events, _ = await self._presence_stream.get_new_events(
UserID.from_string(user), from_key=None, include_offline=False
)
# Send to remote destinations
await make_deferred_yieldable(
# We pull the federation sender here as we can only do so on workers
# that support sending presence
self._hs.get_federation_sender().send_presence(presence_events)
)
class PublicRoomListManager: class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room """Contains methods for adding to, removing from and querying whether a room

View file

@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a # a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus. # background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id self._logging_context = BackgroundProcessLoggingContext(
self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) "replication-conn", self.conn_id
)
def connectionMade(self): def connectionMade(self):
logger.info("[%s] Connection established", self.id()) logger.info("[%s] Connection established", self.id())

View file

@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
clock=self.clock, clock=self.clock,
# don't spider URLs more often than once an hour # don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR, expiry_ms=ONE_HOUR,
) ) # type: ExpiringCache[str, ObservableDeferred]
if self._worker_run_media_background_jobs: if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call( self._cleaner_loop = self.clock.looping_call(

View file

@ -51,6 +51,7 @@ from synapse.crypto import context_factory
from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.context_factory import RegularPolicyForHTTPS
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.events.presence_router import PresenceRouter
from synapse.events.spamcheck import SpamChecker from synapse.events.spamcheck import SpamChecker
from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.events.utils import EventClientSerializer from synapse.events.utils import EventClientSerializer
@ -425,6 +426,10 @@ class HomeServer(metaclass=abc.ABCMeta):
else: else:
raise Exception("Workers cannot write typing") raise Exception("Workers cannot write typing")
@cache_in_self
def get_presence_router(self) -> PresenceRouter:
return PresenceRouter(self)
@cache_in_self @cache_in_self
def get_typing_handler(self) -> FollowerTypingHandler: def get_typing_handler(self) -> FollowerTypingHandler:
if self.config.worker.writers.typing == self.get_instance_name(): if self.config.worker.writers.typing == self.get_instance_name():

View file

@ -22,6 +22,7 @@ from typing import (
Callable, Callable,
DefaultDict, DefaultDict,
Dict, Dict,
FrozenSet,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -515,7 +516,7 @@ class StateResolutionHandler:
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True, iterable=True,
reset_expiry_on_get=True, reset_expiry_on_get=True,
) ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
# #
# stuff for tracking time spent on state-res by room # stuff for tracking time spent on state-res by room
@ -536,7 +537,7 @@ class StateResolutionHandler:
state_groups_ids: Dict[int, StateMap[str]], state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore", state_res_store: "StateResolutionStore",
): ) -> _StateCacheEntry:
"""Resolves conflicts between a set of state groups """Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should Always generates a new state group (unless we hit the cache), so should

View file

@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method" "media_repository_drop_index_wo_method"
) )
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
"media_repository_drop_index_wo_method_2"
)
class MediaSortOrder(Enum): class MediaSortOrder(Enum):
@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
unique=True, unique=True,
) )
# the original impl of _drop_media_index_without_method was broken (see
# https://github.com/matrix-org/synapse/issues/8649), so we replace the original
# impl with a no-op and run the fixed migration as
# media_repository_drop_index_wo_method_2.
self.db_pool.updates.register_noop_background_update(
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD, BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
self._drop_media_index_without_method, self._drop_media_index_without_method,
) )
async def _drop_media_index_without_method(self, progress, batch_size): async def _drop_media_index_without_method(self, progress, batch_size):
"""background update handler which removes the old constraints.
Note that this is only run on postgres.
"""
def f(txn): def f(txn):
txn.execute( txn.execute(
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
) )
txn.execute( txn.execute(
"ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key" "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
) )
await self.db_pool.runInteraction("drop_media_indices_without_method", f) await self.db_pool.runInteraction("drop_media_indices_without_method", f)
await self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
) )
return 1 return 1

View file

@ -0,0 +1,22 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- drop old constraints on remote_media_cache_thumbnails
--
-- This was originally part of 57.07, but it was done wrong, per
-- https://github.com/matrix-org/synapse/issues/8649, so we do it again.
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx');

View file

@ -15,40 +15,50 @@
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Generic, Optional, TypeVar, Union, overload
import attr
from typing_extensions import Literal
from synapse.config import cache as cache_config from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SENTINEL = object() SENTINEL = object() # type: Any
class ExpiringCache: T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
class ExpiringCache(Generic[KT, VT]):
def __init__( def __init__(
self, self,
cache_name, cache_name: str,
clock, clock: Clock,
max_len=0, max_len: int = 0,
expiry_ms=0, expiry_ms: int = 0,
reset_expiry_on_get=False, reset_expiry_on_get: bool = False,
iterable=False, iterable: bool = False,
): ):
""" """
Args: Args:
cache_name (str): Name of this cache, used for logging. cache_name: Name of this cache, used for logging.
clock (Clock) clock
max_len (int): Max size of dict. If the dict grows larger than this max_len: Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0, then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit. which indicates there is no max limit.
expiry_ms (int): How long before an item is evicted from the cache expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get in milliseconds. Default is 0, indicating items never get
evicted based on time. evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for reset_expiry_on_get: If true, will reset the expiry time for
an item on access. Defaults to False. an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the iterable: If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries. sizes of all entries, rather than the number of entries.
""" """
self._cache_name = cache_name self._cache_name = cache_name
@ -62,7 +72,7 @@ class ExpiringCache:
self._expiry_ms = expiry_ms self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get self._reset_expiry_on_get = reset_expiry_on_get
self._cache = OrderedDict() self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry]
self.iterable = iterable self.iterable = iterable
@ -79,12 +89,12 @@ class ExpiringCache:
self._clock.looping_call(f, self._expiry_ms / 2) self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value): def __setitem__(self, key: KT, value: VT) -> None:
now = self._clock.time_msec() now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value) self._cache[key] = _CacheEntry(now, value)
self.evict() self.evict()
def evict(self): def evict(self) -> None:
# Evict if there are now too many items # Evict if there are now too many items
while self._max_size and len(self) > self._max_size: while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False) _key, value = self._cache.popitem(last=False)
@ -93,7 +103,7 @@ class ExpiringCache:
else: else:
self.metrics.inc_evictions() self.metrics.inc_evictions()
def __getitem__(self, key): def __getitem__(self, key: KT) -> VT:
try: try:
entry = self._cache[key] entry = self._cache[key]
self.metrics.inc_hits() self.metrics.inc_hits()
@ -106,7 +116,7 @@ class ExpiringCache:
return entry.value return entry.value
def pop(self, key, default=SENTINEL): def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Removes and returns the value with the given key from the cache. """Removes and returns the value with the given key from the cache.
If the key isn't in the cache then `default` will be returned if If the key isn't in the cache then `default` will be returned if
@ -115,29 +125,40 @@ class ExpiringCache:
Identical functionality to `dict.pop(..)`. Identical functionality to `dict.pop(..)`.
""" """
value = self._cache.pop(key, default) value = self._cache.pop(key, SENTINEL)
# The key was not found.
if value is SENTINEL: if value is SENTINEL:
raise KeyError(key) if default is SENTINEL:
raise KeyError(key)
return default
return value return value.value
def __contains__(self, key): def __contains__(self, key: KT) -> bool:
return key in self._cache return key in self._cache
def get(self, key, default=None): @overload
def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
...
@overload
def get(self, key: KT, default: T) -> Union[VT, T]:
...
def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try: try:
return self[key] return self[key]
except KeyError: except KeyError:
return default return default
def setdefault(self, key, value): def setdefault(self, key: KT, value: VT) -> VT:
try: try:
return self[key] return self[key]
except KeyError: except KeyError:
self[key] = value self[key] = value
return value return value
def _prune_cache(self): def _prune_cache(self) -> None:
if not self._expiry_ms: if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called # zero expiry time means don't expire. This should never get called
# since we have this check in start too. # since we have this check in start too.
@ -166,7 +187,7 @@ class ExpiringCache:
len(self), len(self),
) )
def __len__(self): def __len__(self) -> int:
if self.iterable: if self.iterable:
return sum(len(entry.value) for entry in self._cache.values()) return sum(len(entry.value) for entry in self._cache.values())
else: else:
@ -190,9 +211,7 @@ class ExpiringCache:
return False return False
@attr.s(slots=True)
class _CacheEntry: class _CacheEntry:
__slots__ = ["time", "value"] time = attr.ib(type=int)
value = attr.ib()
def __init__(self, time, value):
self.time = time
self.value = value

View file

@ -20,6 +20,7 @@ from io import StringIO
import yaml import yaml
from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from tests import unittest from tests import unittest
@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_load_fails_if_server_name_missing(self): def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name") self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(Exception): with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.file]) HomeServerConfig.load_config("", ["-c", self.file])
with self.assertRaises(Exception): with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.file]) HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self): def test_generates_and_loads_macaroon_secret_key(self):

View file

@ -16,6 +16,7 @@ import time
from mock import Mock from mock import Mock
import attr
import canonicaljson import canonicaljson
import signedjson.key import signedjson.key
import signedjson.sign import signedjson.sign
@ -68,6 +69,11 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key) signedjson.sign.sign_json(res, self.server_name, self.key)
@attr.s(slots=True)
class FakeRequest:
id = attr.ib()
@logcontext_clean @logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase): class KeyringTestCase(unittest.HomeserverTestCase):
def check_context(self, val, expected): def check_context(self, val, expected):
@ -89,7 +95,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
first_lookup_deferred = Deferred() first_lookup_deferred = Deferred()
async def first_lookup_fetch(keys_to_fetch): async def first_lookup_fetch(keys_to_fetch):
self.assertEquals(current_context().request, "context_11") self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
await make_deferred_yieldable(first_lookup_deferred) await make_deferred_yieldable(first_lookup_deferred)
@ -102,9 +108,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = first_lookup_fetch mock_fetcher.get_keys.side_effect = first_lookup_fetch
async def first_lookup(): async def first_lookup():
with LoggingContext("context_11") as context_11: with LoggingContext("context_11", request=FakeRequest("context_11")):
context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server( res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
) )
@ -130,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should block rather than start a second call # should block rather than start a second call
async def second_lookup_fetch(keys_to_fetch): async def second_lookup_fetch(keys_to_fetch):
self.assertEquals(current_context().request, "context_12") self.assertEquals(current_context().request.id, "context_12")
return { return {
"server10": { "server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
@ -142,9 +146,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
second_lookup_state = [0] second_lookup_state = [0]
async def second_lookup(): async def second_lookup():
with LoggingContext("context_12") as context_12: with LoggingContext("context_12", request=FakeRequest("context_12")):
context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")] [("server10", json1, 0, "test")]
) )
@ -589,10 +591,7 @@ def get_key_id(key):
@defer.inlineCallbacks @defer.inlineCallbacks
def run_in_context(f, *args, **kwargs): def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx: with LoggingContext("testctx"):
# we set the "request" prop to make it easier to follow what's going on in the
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs) rv = yield f(*args, **kwargs)
return rv return rv

View file

@ -0,0 +1,386 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from mock import Mock
import attr
from synapse.api.constants import EduTypes
from synapse.events.presence_router import PresenceRouter
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client.v1 import login, presence, room
from synapse.types import JsonDict, StreamToken, create_requester
from tests.handlers.test_sync import generate_sync_config
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
@attr.s
class PresenceRouterTestConfig:
users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
class PresenceRouterTestModule:
def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
self._config = config
self._module_api = module_api
async def get_users_for_states(
self, state_updates: Iterable[UserPresenceState]
) -> Dict[str, Set[UserPresenceState]]:
users_to_state = {
user_id: set(state_updates)
for user_id in self._config.users_who_should_receive_all_presence
}
return users_to_state
async def get_interested_users(
self, user_id: str
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS
return set()
@staticmethod
def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
"""Parse a configuration dictionary from the homeserver config, do
some validation and return a typed PresenceRouterConfig.
Args:
config_dict: The configuration dictionary.
Returns:
A validated config object.
"""
# Initialise a typed config object
config = PresenceRouterTestConfig()
config.users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
return config
class PresenceRouterTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
presence.register_servlets,
]
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
def prepare(self, reactor, clock, homeserver):
self.sync_handler = self.hs.get_sync_handler()
self.module_api = homeserver.get_module_api()
@override_config(
{
"presence": {
"presence_router": {
"module": __name__ + ".PresenceRouterTestModule",
"config": {
"users_who_should_receive_all_presence": [
"@presence_gobbler:test",
]
},
}
},
"send_federation": True,
}
)
def test_receiving_all_presence(self):
"""Test that a user that does not share a room with another other can receive
presence for them, due to presence routing.
"""
# Create a user who should receive all presence of others
self.presence_receiving_user_id = self.register_user(
"presence_gobbler", "monkey"
)
self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey")
# And two users who should not have any special routing
self.other_user_one_id = self.register_user("other_user_one", "monkey")
self.other_user_one_tok = self.login("other_user_one", "monkey")
self.other_user_two_id = self.register_user("other_user_two", "monkey")
self.other_user_two_tok = self.login("other_user_two", "monkey")
# Put the other two users in a room with each other
room_id = self.helper.create_room_as(
self.other_user_one_id, tok=self.other_user_one_tok
)
self.helper.invite(
room_id,
self.other_user_one_id,
self.other_user_two_id,
tok=self.other_user_one_tok,
)
self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok)
# User one sends some presence
send_presence_update(
self,
self.other_user_one_id,
self.other_user_one_tok,
"online",
"boop",
)
# Check that the presence receiving user gets user one's presence when syncing
presence_updates, sync_token = sync_presence(
self, self.presence_receiving_user_id
)
self.assertEqual(len(presence_updates), 1)
presence_update = presence_updates[0] # type: UserPresenceState
self.assertEqual(presence_update.user_id, self.other_user_one_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "boop")
# Have all three users send presence
send_presence_update(
self,
self.other_user_one_id,
self.other_user_one_tok,
"online",
"user_one",
)
send_presence_update(
self,
self.other_user_two_id,
self.other_user_two_tok,
"online",
"user_two",
)
send_presence_update(
self,
self.presence_receiving_user_id,
self.presence_receiving_user_tok,
"online",
"presence_gobbler",
)
# Check that the presence receiving user gets everyone's presence
presence_updates, _ = sync_presence(
self, self.presence_receiving_user_id, sync_token
)
self.assertEqual(len(presence_updates), 3)
# But that User One only get itself and User Two's presence
presence_updates, _ = sync_presence(self, self.other_user_one_id)
self.assertEqual(len(presence_updates), 2)
found = False
for update in presence_updates:
if update.user_id == self.other_user_two_id:
self.assertEqual(update.state, "online")
self.assertEqual(update.status_msg, "user_two")
found = True
self.assertTrue(found)
@override_config(
{
"presence": {
"presence_router": {
"module": __name__ + ".PresenceRouterTestModule",
"config": {
"users_who_should_receive_all_presence": [
"@presence_gobbler1:test",
"@presence_gobbler2:test",
"@far_away_person:island",
]
},
}
},
"send_federation": True,
}
)
def test_send_local_online_presence_to_with_module(self):
"""Tests that send_local_presence_to_users sends local online presence to a set
of specified local and remote users, with a custom PresenceRouter module enabled.
"""
# Create a user who will send presence updates
self.other_user_id = self.register_user("other_user", "monkey")
self.other_user_tok = self.login("other_user", "monkey")
# And another two users that will also send out presence updates, as well as receive
# theirs and everyone else's
self.presence_receiving_user_one_id = self.register_user(
"presence_gobbler1", "monkey"
)
self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey")
self.presence_receiving_user_two_id = self.register_user(
"presence_gobbler2", "monkey"
)
self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey")
# Have all three users send some presence updates
send_presence_update(
self,
self.other_user_id,
self.other_user_tok,
"online",
"I'm online!",
)
send_presence_update(
self,
self.presence_receiving_user_one_id,
self.presence_receiving_user_one_tok,
"online",
"I'm also online!",
)
send_presence_update(
self,
self.presence_receiving_user_two_id,
self.presence_receiving_user_two_tok,
"unavailable",
"I'm in a meeting!",
)
# Mark each presence-receiving user for receiving all user presence
self.get_success(
self.module_api.send_local_online_presence_to(
[
self.presence_receiving_user_one_id,
self.presence_receiving_user_two_id,
]
)
)
# Perform a sync for each user
# The other user should only receive their own presence
presence_updates, _ = sync_presence(self, self.other_user_id)
self.assertEqual(len(presence_updates), 1)
presence_update = presence_updates[0] # type: UserPresenceState
self.assertEqual(presence_update.user_id, self.other_user_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "I'm online!")
# Whereas both presence receiving users should receive everyone's presence updates
presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id)
self.assertEqual(len(presence_updates), 3)
presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
self.assertEqual(len(presence_updates), 3)
# Test that sending to a remote user works
remote_user_id = "@far_away_person:island"
# Note that due to the remote user being in our module's
# users_who_should_receive_all_presence config, they would have
# received user presence updates already.
#
# Thus we reset the mock, and try sending all online local user
# presence again
self.hs.get_federation_transport_client().send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
self.module_api.send_local_online_presence_to([remote_user_id])
)
# Check that the expected presence updates were sent
expected_users = [
self.other_user_id,
self.presence_receiving_user_one_id,
self.presence_receiving_user_two_id,
]
calls = (
self.hs.get_federation_transport_client().send_transaction.call_args_list
)
for call in calls:
federation_transaction = call.args[0] # type: Transaction
# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
for edu in edus:
# Make sure we're only checking presence-type EDUs
if edu["edu_type"] != EduTypes.Presence:
continue
# EDUs can contain multiple presence updates
for presence_update in edu["content"]["push"]:
# Check for presence updates that contain the user IDs we're after
expected_users.remove(presence_update["user_id"])
# Ensure that no offline states are being sent out
self.assertNotEqual(presence_update["presence"], "offline")
self.assertEqual(len(expected_users), 0)
def send_presence_update(
testcase: TestCase,
user_id: str,
access_token: str,
presence_state: str,
status_message: Optional[str] = None,
) -> JsonDict:
# Build the presence body
body = {"presence": presence_state}
if status_message:
body["status_msg"] = status_message
# Update the user's presence state
channel = testcase.make_request(
"PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token
)
testcase.assertEqual(channel.code, 200)
return channel.json_body
def sync_presence(
testcase: TestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
"""Perform a sync request for the given user and return the user presence updates
they've received, as well as the next_batch token.
This method assumes testcase.sync_handler points to the homeserver's sync handler.
Args:
testcase: The testcase that is currently being run.
user_id: The ID of the user to generate a sync response for.
since_token: An optional token to indicate from at what point to sync from.
Returns:
A tuple containing a list of presence updates, and the sync response's
next_batch token.
"""
requester = create_requester(user_id)
sync_config = generate_sync_config(requester.user.to_string())
sync_result = testcase.get_success(
testcase.sync_handler.wait_for_sync_for_user(
requester, sync_config, since_token
)
)
return sync_result.presence, sync_result.next_batch

View file

@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self): def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test" user_id1 = "@user1:test"
user_id2 = "@user2:test" user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1) sync_config = generate_sync_config(user_id1)
requester = create_requester(user_id1) requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time self.reactor.advance(100) # So we get not 0 time
@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2) sync_config = generate_sync_config(user_id2)
requester = create_requester(user_id2) requester = create_requester(user_id2)
e = self.get_failure( e = self.get_failure(
@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig( def generate_sync_config(user_id: str) -> SyncConfig:
user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), return SyncConfig(
filter_collection=DEFAULT_FILTER_COLLECTION, user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
is_guest=False, filter_collection=DEFAULT_FILTER_COLLECTION,
request_key="request_key", is_guest=False,
device_id="device_id", request_key="request_key",
) device_id="device_id",
)

View file

@ -12,15 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import logging import logging
from io import StringIO from io import BytesIO, StringIO
from mock import Mock, patch
from twisted.web.server import Request
from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin from tests.logging import LoggerCleanupMixin
from tests.server import FakeChannel
from tests.unittest import TestCase from tests.unittest import TestCase
@ -120,7 +125,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
handler.addFilter(LoggingContextFilter()) handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler) logger = self.get_logger(handler)
with LoggingContext(request="test"): with LoggingContext("name"):
logger.info("Hello there, %s!", "wally") logger.info("Hello there, %s!", "wally")
log = self.get_log_line() log = self.get_log_line()
@ -134,4 +139,61 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
] ]
self.assertCountEqual(log.keys(), expected_log_keys) self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "test") self.assertTrue(log["request"].startswith("name@"))
def test_with_request_context(self):
"""
Information from the logging context request should be added to the JSON response.
"""
handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
# A full request isn't needed here.
site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
site.site_tag = "test-site"
site.server_version_string = "Server v1"
request = SynapseRequest(FakeChannel(site, None))
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
# Partially skip some of the internal processing of SynapseRequest.
request._started_processing = Mock()
request.request_metrics = Mock(spec=["name"])
with patch.object(Request, "render"):
request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
# Also set the requester to ensure the processing works.
request.requester = "@foo:test"
with LoggingContext(parent_context=request.logcontext):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
# The terse logger includes additional request information, if possible.
expected_log_keys = [
"log",
"level",
"namespace",
"request",
"ip_address",
"site_tag",
"requester",
"authenticated_entity",
"method",
"url",
"protocol",
"user_agent",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertTrue(log["request"].startswith("POST-"))
self.assertEqual(log["ip_address"], "127.0.0.1")
self.assertEqual(log["site_tag"], "test-site")
self.assertEqual(log["requester"], "@foo:test")
self.assertEqual(log["authenticated_entity"], "@foo:test")
self.assertEqual(log["method"], "POST")
self.assertEqual(log["url"], "/_matrix/client/versions")
self.assertEqual(log["protocol"], "1.1")
self.assertEqual(log["user_agent"], "")

View file

@ -14,25 +14,37 @@
# limitations under the License. # limitations under the License.
from mock import Mock from mock import Mock
from synapse.api.constants import EduTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, presence, room
from synapse.types import create_requester from synapse.types import create_requester
from tests.unittest import HomeserverTestCase from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import FederatingHomeserverTestCase, override_config
class ModuleApiTestCase(HomeserverTestCase): class ModuleApiTestCase(FederatingHomeserverTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
login.register_servlets, login.register_servlets,
room.register_servlets, room.register_servlets,
presence.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore() self.store = homeserver.get_datastore()
self.module_api = homeserver.get_module_api() self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler() self.event_creation_handler = homeserver.get_event_creation_handler()
self.sync_handler = homeserver.get_sync_handler()
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
def test_can_register_user(self): def test_can_register_user(self):
"""Tests that an external module can register a user""" """Tests that an external module can register a user"""
@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase):
) )
) )
self.assertFalse(is_in_public_rooms) self.assertFalse(is_in_public_rooms)
# The ability to send federation is required by send_local_online_presence_to.
@override_config({"send_federation": True})
def test_send_local_online_presence_to(self):
"""Tests that send_local_presence_to_users sends local online presence to local users."""
# Create a user who will send presence updates
self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
self.presence_receiver_tok = self.login("presence_receiver", "monkey")
# And another user that will send presence updates out
self.presence_sender_id = self.register_user("presence_sender", "monkey")
self.presence_sender_tok = self.login("presence_sender", "monkey")
# Put them in a room together so they will receive each other's presence updates
room_id = self.helper.create_room_as(
self.presence_receiver_id,
tok=self.presence_receiver_tok,
)
self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
# Presence sender comes online
send_presence_update(
self,
self.presence_sender_id,
self.presence_sender_tok,
"online",
"I'm online!",
)
# Presence receiver should have received it
presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
self.assertEqual(len(presence_updates), 1)
presence_update = presence_updates[0] # type: UserPresenceState
self.assertEqual(presence_update.user_id, self.presence_sender_id)
self.assertEqual(presence_update.state, "online")
# Syncing again should result in no presence updates
presence_updates, sync_token = sync_presence(
self, self.presence_receiver_id, sync_token
)
self.assertEqual(len(presence_updates), 0)
# Trigger sending local online presence
self.get_success(
self.module_api.send_local_online_presence_to(
[
self.presence_receiver_id,
]
)
)
# Presence receiver should have received online presence again
presence_updates, sync_token = sync_presence(
self, self.presence_receiver_id, sync_token
)
self.assertEqual(len(presence_updates), 1)
presence_update = presence_updates[0] # type: UserPresenceState
self.assertEqual(presence_update.user_id, self.presence_sender_id)
self.assertEqual(presence_update.state, "online")
# Presence sender goes offline
send_presence_update(
self,
self.presence_sender_id,
self.presence_sender_tok,
"offline",
"I slink back into the darkness.",
)
# Trigger sending local online presence
self.get_success(
self.module_api.send_local_online_presence_to(
[
self.presence_receiver_id,
]
)
)
# Presence receiver should *not* have received offline state
presence_updates, sync_token = sync_presence(
self, self.presence_receiver_id, sync_token
)
self.assertEqual(len(presence_updates), 0)
@override_config({"send_federation": True})
def test_send_local_online_presence_to_federation(self):
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender", "monkey")
self.presence_sender_tok = self.login("presence_sender", "monkey")
# And a room they're a part of
room_id = self.helper.create_room_as(
self.presence_sender_id,
tok=self.presence_sender_tok,
)
# Mark them as online
send_presence_update(
self,
self.presence_sender_id,
self.presence_sender_tok,
"online",
"I'm online!",
)
# Make up a remote user to send presence to
remote_user_id = "@far_away_person:island"
# Create a join membership event for the remote user into the room.
# This allows presence information to flow from one user to the other.
self.get_success(
inject_member_event(
self.hs,
room_id,
sender=remote_user_id,
target=remote_user_id,
membership="join",
)
)
# The remote user would have received the existing room members' presence
# when they joined the room.
#
# Thus we reset the mock, and try sending online local user
# presence again
self.hs.get_federation_transport_client().send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
self.module_api.send_local_online_presence_to([remote_user_id])
)
# Check that a presence update was sent as part of a federation transaction
found_update = False
calls = (
self.hs.get_federation_transport_client().send_transaction.call_args_list
)
for call in calls:
federation_transaction = call.args[0] # type: Transaction
# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
for edu in edus:
# Make sure we're only checking presence-type EDUs
if edu["edu_type"] != EduTypes.Presence:
continue
# EDUs can contain multiple presence updates
for presence_update in edu["content"]["push"]:
if presence_update["user_id"] == self.presence_sender_id:
found_update = True
self.assertTrue(found_update)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import synapse.api.errors import synapse.api.errors
import tests.unittest from tests.unittest import HomeserverTestCase
import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase): class DeviceStoreTestCase(HomeserverTestCase):
def __init__(self, *args, **kwargs): def prepare(self, reactor, clock, hs):
super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_store_new_device(self): def test_store_new_device(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device_id", "display_name") self.store.store_device("user_id", "device_id", "display_name")
) )
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"user_id": "user_id", "user_id": "user_id",
@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res, res,
) )
@defer.inlineCallbacks
def test_get_devices_by_user(self): def test_get_devices_by_user(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device1", "display_name 1") self.store.store_device("user_id", "device1", "display_name 1")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device2", "display_name 2") self.store.store_device("user_id", "device2", "display_name 2")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3") self.store.store_device("user_id2", "device3", "display_name 3")
) )
res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys())) self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"], res["device2"],
) )
@defer.inlineCallbacks
def test_count_devices_by_users(self): def test_count_devices_by_users(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device1", "display_name 1") self.store.store_device("user_id", "device1", "display_name 1")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device2", "display_name 2") self.store.store_device("user_id", "device2", "display_name 2")
) )
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3") self.store.store_device("user_id2", "device3", "display_name 3")
) )
res = yield defer.ensureDeferred(self.store.count_devices_by_users()) res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res) self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res) self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res) self.assertEqual(2, res)
res = yield defer.ensureDeferred( res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"]) self.store.count_devices_by_users(["user_id", "user_id2"])
) )
self.assertEqual(3, res) self.assertEqual(3, res)
@defer.inlineCallbacks
def test_get_device_updates_by_remote(self): def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id # Add two device updates with a single stream_id
yield defer.ensureDeferred( self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
) )
# Get all device updates ever meant for this remote # Get all device updates ever meant for this remote
now_stream_id, device_updates = yield defer.ensureDeferred( now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100) self.store.get_device_updates_by_remote("somehost", -1, limit=100)
) )
@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
} }
self.assertEqual(received_device_ids, set(expected_device_ids)) self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
def test_update_device(self): def test_update_device(self):
yield defer.ensureDeferred( self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1") self.store.store_device("user_id", "device_id", "display_name 1")
) )
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do a no-op first # do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) self.get_success(self.store.update_device("user_id", "device_id"))
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do the update # do the update
yield defer.ensureDeferred( self.get_success(
self.store.update_device( self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2" "user_id", "device_id", new_display_name="display_name 2"
) )
) )
# check it worked # check it worked
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"]) self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self): def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm: exc = self.get_failure(
yield defer.ensureDeferred( self.store.update_device(
self.store.update_device( "user_id", "unknown_device_id", new_display_name="display_name 2"
"user_id", "unknown_device_id", new_display_name="display_name 2" ),
) synapse.api.errors.StoreError,
) )
self.assertEqual(404, cm.exception.code) self.assertEqual(404, exc.value.code)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import RoomAlias, RoomID from synapse.types import RoomAlias, RoomID
from tests import unittest from tests.unittest import HomeserverTestCase
from tests.utils import setup_test_homeserver
class DirectoryStoreTestCase(unittest.TestCase): class DirectoryStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test") self.alias = RoomAlias.from_string("#my-room:test")
@defer.inlineCallbacks
def test_room_to_alias(self): def test_room_to_alias(self):
yield defer.ensureDeferred( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
) )
@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
["#my-room:test"], ["#my-room:test"],
( (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
yield defer.ensureDeferred(
self.store.get_aliases_for_room(self.room.to_string())
)
),
) )
@defer.inlineCallbacks
def test_alias_to_room(self): def test_alias_to_room(self):
yield defer.ensureDeferred( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
) )
@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]}, {"room_id": self.room.to_string(), "servers": ["test"]},
( (self.get_success(self.store.get_association_from_room_alias(self.alias))),
yield defer.ensureDeferred(
self.store.get_association_from_room_alias(self.alias)
)
),
) )
@defer.inlineCallbacks
def test_delete_alias(self): def test_delete_alias(self):
yield defer.ensureDeferred( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
) )
) )
room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id) self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone( self.assertIsNone(
( (self.get_success(self.store.get_association_from_room_alias(self.alias)))
yield defer.ensureDeferred(
self.store.get_association_from_room_alias(self.alias)
)
)
) )

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from tests.unittest import HomeserverTestCase
import tests.unittest
import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase): class EndToEndKeyStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_key_without_device_name(self): def test_key_without_device_name(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield defer.ensureDeferred(self.store.store_device("user", "device", None)) self.get_success(self.store.store_device("user", "device", None))
yield defer.ensureDeferred( self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.store.set_e2e_device_keys("user", "device", now, json)
)
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
) )
self.assertIn("user", res) self.assertIn("user", res)
@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev) self.assertDictContainsSubset(json, dev)
@defer.inlineCallbacks
def test_reupload_key(self): def test_reupload_key(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield defer.ensureDeferred(self.store.store_device("user", "device", None)) self.get_success(self.store.store_device("user", "device", None))
changed = yield defer.ensureDeferred( changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json) self.store.set_e2e_device_keys("user", "device", now, json)
) )
self.assertTrue(changed) self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing # If we try to upload the same key then we should be told nothing
# changed # changed
changed = yield defer.ensureDeferred( changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json) self.store.set_e2e_device_keys("user", "device", now, json)
) )
self.assertFalse(changed) self.assertFalse(changed)
@defer.inlineCallbacks
def test_get_key_with_device_name(self): def test_get_key_with_device_name(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield defer.ensureDeferred( self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
self.store.set_e2e_device_keys("user", "device", now, json) self.get_success(self.store.store_device("user", "device", "display_name"))
)
yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name")
)
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
) )
self.assertIn("user", res) self.assertIn("user", res)
@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
) )
@defer.inlineCallbacks
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) self.get_success(self.store.store_device("user1", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) self.get_success(self.store.store_device("user1", "device2", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) self.get_success(self.store.store_device("user2", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) self.get_success(self.store.store_device("user2", "device2", None))
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
) )
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
) )
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
) )
yield defer.ensureDeferred( self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
) )
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api( self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2")) (("user1", "device1"), ("user2", "device2"))
) )

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,10 +15,7 @@
from mock import Mock from mock import Mock
from twisted.internet import defer from tests.unittest import HomeserverTestCase
import tests.unittest
import tests.utils
USER_ID = "@user:example.com" USER_ID = "@user:example.com"
@ -30,37 +27,31 @@ HIGHLIGHT = [
] ]
class EventPushActionsStoreTestCase(tests.unittest.TestCase): class EventPushActionsStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self): def test_get_unread_push_actions_for_user_in_range_for_http(self):
yield defer.ensureDeferred( self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http( self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
) )
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self): def test_get_unread_push_actions_for_user_in_range_for_email(self):
yield defer.ensureDeferred( self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email( self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
) )
@defer.inlineCallbacks
def test_count_aggregation(self): def test_count_aggregation(self):
room_id = "!foo:example.com" room_id = "!foo:example.com"
user_id = "@user1235:example.com" user_id = "@user1235:example.com"
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count): def _assert_counts(noitf_count, highlight_count):
counts = yield defer.ensureDeferred( counts = self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
) )
@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def _inject_actions(stream, action): def _inject_actions(stream, action):
event = Mock() event = Mock()
event.room_id = room_id event.room_id = room_id
@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
event.depth = stream event.depth = stream
yield defer.ensureDeferred( self.get_success(
self.store.add_push_actions_to_staging( self.store.add_push_actions_to_staging(
event.event_id, event.event_id,
{user_id: action}, {user_id: action},
False, False,
) )
) )
yield defer.ensureDeferred( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", "",
self.persist_events_store._set_push_actions_for_event_and_users_txn, self.persist_events_store._set_push_actions_for_event_and_users_txn,
@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
def _rotate(stream): def _rotate(stream):
return defer.ensureDeferred( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream "", self.store._rotate_notifs_before_txn, stream
) )
) )
def _mark_read(stream, depth): def _mark_read(stream, depth):
return defer.ensureDeferred( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", "",
self.store._remove_old_push_actions_before_txn, self.store._remove_old_push_actions_before_txn,
@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
) )
yield _assert_counts(0, 0) _assert_counts(0, 0)
yield _inject_actions(1, PlAIN_NOTIF) _inject_actions(1, PlAIN_NOTIF)
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _rotate(2) _rotate(2)
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _inject_actions(3, PlAIN_NOTIF) _inject_actions(3, PlAIN_NOTIF)
yield _assert_counts(2, 0) _assert_counts(2, 0)
yield _rotate(4) _rotate(4)
yield _assert_counts(2, 0) _assert_counts(2, 0)
yield _inject_actions(5, PlAIN_NOTIF) _inject_actions(5, PlAIN_NOTIF)
yield _mark_read(3, 3) _mark_read(3, 3)
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _mark_read(5, 5) _mark_read(5, 5)
yield _assert_counts(0, 0) _assert_counts(0, 0)
yield _inject_actions(6, PlAIN_NOTIF) _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7) _rotate(7)
yield defer.ensureDeferred( self.get_success(
self.store.db_pool.simple_delete( self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc="" table="event_push_actions", keyvalues={"1": 1}, desc=""
) )
) )
yield _assert_counts(1, 0) _assert_counts(1, 0)
yield _mark_read(7, 7) _mark_read(7, 7)
yield _assert_counts(0, 0) _assert_counts(0, 0)
yield _inject_actions(8, HIGHLIGHT) _inject_actions(8, HIGHLIGHT)
yield _assert_counts(1, 1) _assert_counts(1, 1)
yield _rotate(9) _rotate(9)
yield _assert_counts(1, 1) _assert_counts(1, 1)
yield _rotate(10) _rotate(10)
yield _assert_counts(1, 1) _assert_counts(1, 1)
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self): def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts): def add_event(so, ts):
return defer.ensureDeferred( self.get_success(
self.store.db_pool.simple_insert( self.store.db_pool.simple_insert(
"events", "events",
{ {
@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
# start with the base case where there are no events in the table # start with the base case where there are no events in the table
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.store.find_first_stream_ordering_after_ts(11)
)
self.assertEqual(r, 0) self.assertEqual(r, 0)
# now with one event # now with one event
yield add_event(2, 10) add_event(2, 10)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.store.find_first_stream_ordering_after_ts(9)
)
self.assertEqual(r, 2) self.assertEqual(r, 2)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.store.find_first_stream_ordering_after_ts(10)
)
self.assertEqual(r, 2) self.assertEqual(r, 2)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.store.find_first_stream_ordering_after_ts(11)
)
self.assertEqual(r, 3) self.assertEqual(r, 3)
# add a bunch of dummy events to the events table # add a bunch of dummy events to the events table
@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130), (10, 130),
(20, 140), (20, 140),
): ):
yield add_event(stream_ordering, ts) add_event(stream_ordering, ts)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.store.find_first_stream_ordering_after_ts(110)
)
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5 # 4 and 5 are both after 120: we want 4 rather than 5
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.store.find_first_stream_ordering_after_ts(120)
)
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.store.find_first_stream_ordering_after_ts(129)
)
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event # check we can get the last event
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.store.find_first_stream_ordering_after_ts(140)
)
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end # off the end
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.store.find_first_stream_ordering_after_ts(160)
)
self.assertEqual(r, 21) self.assertEqual(r, 21)
# check we can find an event at ordering zero # check we can find an event at ordering zero
yield add_event(0, 5) add_event(0, 5)
r = yield defer.ensureDeferred( r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.store.find_first_stream_ordering_after_ts(1)
)
self.assertEqual(r, 0) self.assertEqual(r, 0)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver
class ProfileStoreTestCase(unittest.TestCase): class ProfileStoreTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test") self.u_frank = UserID.from_string("@frank:test")
@defer.inlineCallbacks
def test_displayname(self): def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) self.get_success(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank") self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
) )
self.assertEquals( self.assertEquals(
"Frank", "Frank",
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart) self.store.get_profile_displayname(self.u_frank.localpart)
) )
), ),
) )
# test set to None # test set to None
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None) self.store.set_profile_displayname(self.u_frank.localpart, None)
) )
self.assertIsNone( self.assertIsNone(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart) self.store.get_profile_displayname(self.u_frank.localpart)
) )
) )
) )
@defer.inlineCallbacks
def test_avatar_url(self): def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) self.get_success(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here" self.u_frank.localpart, "http://my.site/here"
) )
@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
"http://my.site/here", "http://my.site/here",
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart) self.store.get_profile_avatar_url(self.u_frank.localpart)
) )
), ),
) )
# test set to None # test set to None
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None) self.store.set_profile_avatar_url(self.u_frank.localpart, None)
) )
self.assertIsNone( self.assertIsNone(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart) self.store.get_profile_avatar_url(self.u_frank.localpart)
) )
) )

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,8 +15,6 @@
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder self._base_builder = base_builder
self._event_id = event_id self._event_id = event_id
@defer.inlineCallbacks async def build(self, prev_event_ids, auth_event_ids):
def build(self, prev_event_ids, auth_event_ids): built_event = await self._base_builder.build(
built_event = yield defer.ensureDeferred( prev_event_ids, auth_event_ids
self._base_builder.build(prev_event_ids, auth_event_ids)
) )
built_event._event_id = self._event_id built_event._event_id = self._event_id

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError from synapse.api.errors import ThreepidValidationError
from tests import unittest from tests.unittest import HomeserverTestCase
from tests.utils import setup_test_homeserver
class RegistrationStoreTestCase(unittest.TestCase): class RegistrationStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_id = "@my-user:test" self.user_id = "@my-user:test"
@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789" self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg" self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks
def test_register(self): def test_register(self):
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals( self.assertEquals(
{ {
@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None, "consent_version": None,
"consent_server_notice_sent": None, "consent_server_notice_sent": None,
"appservice_id": None, "appservice_id": None,
"creation_ts": 1000, "creation_ts": 0,
"user_type": None, "user_type": None,
"deactivated": 0, "deactivated": 0,
"shadow_banned": 0, "shadow_banned": 0,
}, },
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), (self.get_success(self.store.get_user_by_id(self.user_id))),
) )
@defer.inlineCallbacks
def test_add_tokens(self): def test_add_tokens(self):
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
) )
result = yield defer.ensureDeferred( result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id) self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id) self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
# add some tokens # add some tokens
yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
) )
) )
yield defer.ensureDeferred( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
) )
# now delete some # now delete some
yield defer.ensureDeferred( self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
) )
# check they were deleted # check they were deleted
user = yield defer.ensureDeferred( user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertIsNone(user, "access token was not deleted by device_id") self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted # check the one not associated with the device was not deleted
user = yield defer.ensureDeferred( user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertEqual(self.user_id, user.user_id) self.assertEqual(self.user_id, user.user_id)
# now delete the rest # now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) self.get_success(self.store.user_delete_access_tokens(self.user_id))
user = yield defer.ensureDeferred( user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertIsNone(user, "access token was not deleted without device_id") self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks
def test_is_support_user(self): def test_is_support_user(self):
TEST_USER = "@test:test" TEST_USER = "@test:test"
SUPPORT_USER = "@support:test" SUPPORT_USER = "@support:test"
res = yield defer.ensureDeferred(self.store.is_support_user(None)) res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res) self.assertFalse(res)
yield defer.ensureDeferred( self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None) self.store.register_user(user_id=TEST_USER, password_hash=None)
) )
res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res) self.assertFalse(res)
yield defer.ensureDeferred( self.get_success(
self.store.register_user( self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
) )
) )
res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res) self.assertTrue(res)
@defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self): def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on """Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID. /requestToken also inhibits validation errors caused by an unknown session ID.
@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a # Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID. # validation error is caused by the unknown session ID.
try: e = self.get_failure(
yield defer.ensureDeferred( self.store.validate_threepid_session(
self.store.validate_threepid_session( "fake_sid",
"fake_sid", "fake_client_secret",
"fake_client_secret", "fake_token",
"fake_token", 0,
0, ),
) ThreepidValidationError,
) )
except ThreepidValidationError as e: self.assertEquals(e.value.msg, "Unknown session_id", e)
self.assertEquals(e.msg, "Unknown session_id", e)
# Set the config setting to true. # Set the config setting to true.
self.store._ignore_unknown_session_error = True self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching. # Check that now the validation error is caused by the token not matching.
try: e = self.get_failure(
yield defer.ensureDeferred( self.store.validate_threepid_session(
self.store.validate_threepid_session( "fake_sid",
"fake_sid", "fake_client_secret",
"fake_client_secret", "fake_token",
"fake_token", 0,
0, ),
) ThreepidValidationError,
) )
except ThreepidValidationError as e: self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
self.assertEquals(e.msg, "Validation token not found or has expired", e)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID from synapse.types import RoomAlias, RoomID, UserID
from tests import unittest from tests.unittest import HomeserverTestCase
from tests.utils import setup_test_homeserver
class RoomStoreTestCase(unittest.TestCase): class RoomStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
# We can't test RoomStore on its own without the DirectoryStore, for # We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table # management of the 'room_aliases' table
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test") self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test") self.u_creator = UserID.from_string("@creator:test")
yield defer.ensureDeferred( self.get_success(
self.store.store_room( self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(), room_creator_user_id=self.u_creator.to_string(),
@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
) )
) )
@defer.inlineCallbacks
def test_get_room(self): def test_get_room(self):
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"is_public": True, "is_public": True,
}, },
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), (self.get_success(self.store.get_room(self.room.to_string()))),
) )
@defer.inlineCallbacks
def test_get_room_unknown_room(self): def test_get_room_unknown_room(self):
self.assertIsNone( self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
@defer.inlineCallbacks
def test_get_room_with_stats(self): def test_get_room_with_stats(self):
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"public": True, "public": True,
}, },
( (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
) )
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self): def test_get_room_with_stats_unknown_room(self):
self.assertIsNone( self.assertIsNone(
( (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
) )
class RoomEventsStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = setup_test_homeserver(self.addCleanup)
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
yield defer.ensureDeferred( self.get_success(
self.store.store_room( self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id="@creator:text", room_creator_user_id="@creator:text",
@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
) )
) )
@defer.inlineCallbacks
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
yield defer.ensureDeferred( self.get_success(
self.storage.persistence.persist_event( self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
) )
) )
@defer.inlineCallbacks
def STALE_test_room_name(self): def STALE_test_room_name(self):
name = "A-Room-Name" name = "A-Room-Name"
yield self.inject_room_event( self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1 etype=EventTypes.Name, name=name, content={"name": name}, depth=1
) )
state = yield defer.ensureDeferred( state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string()) self.store.get_current_state(room_id=self.room.to_string())
) )
@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0], state[0],
) )
@defer.inlineCallbacks
def STALE_test_room_topic(self): def STALE_test_room_topic(self):
topic = "A place for things" topic = "A place for things"
yield self.inject_room_event( self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
) )
state = yield defer.ensureDeferred( state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string()) self.store.get_current_state(room_id=self.room.to_string())
) )

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,24 +15,18 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
import tests.unittest from tests.unittest import HomeserverTestCase
import tests.utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase): class StateStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state self.state_datastore = self.storage.state.stores.state
@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test") self.room = RoomID.from_string("!abc123:test")
yield defer.ensureDeferred( self.get_success(
self.store.store_room( self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id="@creator:text", room_creator_user_id="@creator:text",
@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
) )
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content): def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
}, },
) )
event, context = yield defer.ensureDeferred( event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield defer.ensureDeferred( self.get_success(self.storage.persistence.persist_event(event, context))
self.storage.persistence.persist_event(event, context)
)
return event return event
@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2)) self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
def test_get_state_groups_ids(self): def test_get_state_groups_ids(self):
e1 = yield self.inject_state_event( e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
self.room, self.u_alice, EventTypes.Create, "", {} e2 = self.inject_state_event(
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield defer.ensureDeferred( state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
) )
@defer.inlineCallbacks
def test_get_state_groups(self): def test_get_state_groups(self):
e1 = yield self.inject_state_event( e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
self.room, self.u_alice, EventTypes.Create, "", {} e2 = self.inject_state_event(
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield defer.ensureDeferred( state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id]) self.storage.state.get_state_groups(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
@defer.inlineCallbacks
def test_get_state_for_event(self): def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever # this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room. # forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event( e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
self.room, self.u_alice, EventTypes.Create, "", {} e2 = self.inject_state_event(
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
e3 = yield self.inject_state_event( e3 = self.inject_state_event(
self.room, self.room,
self.u_alice, self.u_alice,
EventTypes.Member, EventTypes.Member,
self.u_alice.to_string(), self.u_alice.to_string(),
{"membership": Membership.JOIN}, {"membership": Membership.JOIN},
) )
e4 = yield self.inject_state_event( e4 = self.inject_state_event(
self.room, self.room,
self.u_bob, self.u_bob,
EventTypes.Member, EventTypes.Member,
self.u_bob.to_string(), self.u_bob.to_string(),
{"membership": Membership.JOIN}, {"membership": Membership.JOIN},
) )
e5 = yield self.inject_state_event( e5 = self.inject_state_event(
self.room, self.room,
self.u_bob, self.u_bob,
EventTypes.Member, EventTypes.Member,
@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we get the full state as of the final event # check we get the full state as of the final event
state = yield defer.ensureDeferred( state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.storage.state.get_state_for_event(e5.event_id)
)
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
) )
@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
) )
@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the # check we can grab a specific room member without filtering out the
# other event types # other event types
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check that we can grab everything except members # check that we can grab everything except members
state = yield defer.ensureDeferred( state = self.get_success(
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield defer.ensureDeferred( group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
) )
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with wildcard types # with wildcard types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string() room_id = self.room.to_string()
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# wildcard types # wildcard types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
( (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from tests.unittest import HomeserverTestCase, override_config
from tests import unittest
from tests.utils import setup_test_homeserver
ALICE = "@alice:a" ALICE = "@alice:a"
BOB = "@bob:b" BOB = "@bob:b"
@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a" BELA = "@somenickname:a"
class UserDirectoryStoreTestCase(unittest.TestCase): class UserDirectoryStoreTestCase(HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self): self.store = hs.get_datastore()
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares # alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice. # a homeserver with alice.
yield defer.ensureDeferred( self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
self.store.update_profile_in_user_dir(ALICE, "alice", None) self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
) self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
yield defer.ensureDeferred( self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.store.update_profile_in_user_dir(BOB, "bob", None) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BELA, "Bela", None)
)
yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
@defer.inlineCallbacks
def test_search_user_dir(self): def test_search_user_dir(self):
# normally when alice searches the directory she should just find # normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her. # bob because bobby doesn't share a room with her.
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"])) self.assertEqual(1, len(r["results"]))
self.assertDictEqual( self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None} r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
) )
@defer.inlineCallbacks @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self): def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
try: self.assertFalse(r["limited"])
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertEqual(2, len(r["results"]))
self.assertFalse(r["limited"]) self.assertDictEqual(
self.assertEqual(2, len(r["results"])) r["results"][0],
self.assertDictEqual( {"user_id": BOB, "display_name": "bob", "avatar_url": None},
r["results"][0], )
{"user_id": BOB, "display_name": "bob", "avatar_url": None}, self.assertDictEqual(
) r["results"][1],
self.assertDictEqual( {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
r["results"][1], )
{"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
)
finally:
self.hs.config.user_directory_search_all_users = False
@defer.inlineCallbacks @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self): def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its """Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would display name even if that name happens to be a common English word that would
usually be ignored in full text searches. usually be ignored in full text searches.
""" """
self.hs.config.user_directory_search_all_users = True r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
try: self.assertFalse(r["limited"])
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10)) self.assertEqual(1, len(r["results"]))
self.assertFalse(r["limited"]) self.assertDictEqual(
self.assertEqual(1, len(r["results"])) r["results"][0],
self.assertDictEqual( {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
r["results"][0], )
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
finally:
self.hs.config.user_directory_search_all_users = False

View file

@ -471,7 +471,7 @@ class HomeserverTestCase(TestCase):
kwargs["config"] = config_obj kwargs["config"] = config_obj
async def run_bg_updates(): async def run_bg_updates():
with LoggingContext("run_bg_updates", request="run_bg_updates-1"): with LoggingContext("run_bg_updates"):
while not await stor.db_pool.updates.has_completed_background_updates(): while not await stor.db_pool.updates.has_completed_background_updates():
await stor.db_pool.updates.do_next_background_update(1) await stor.db_pool.updates.do_next_background_update(1)

View file

@ -661,14 +661,13 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1") @descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2): async def list_fn(self, args1, arg2):
assert current_context().request == "c1" assert current_context().name == "c1"
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
await run_on_reactor() await run_on_reactor()
assert current_context().request == "c1" assert current_context().name == "c1"
return self.mock(args1, arg2) return self.mock(args1, arg2)
with LoggingContext() as c1: with LoggingContext("c1") as c1:
c1.request = "c1"
obj = Cls() obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"} obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2) d1 = obj.list_fn([10, 20], 2)

View file

@ -17,11 +17,10 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase): class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value): def _check_test_key(self, value):
self.assertEquals(current_context().request, value) self.assertEquals(current_context().name, value)
def test_with_context(self): def test_with_context(self):
with LoggingContext() as context_one: with LoggingContext("test"):
context_one.request = "test"
self._check_test_key("test") self._check_test_key("test")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def competing_callback(): def competing_callback():
with LoggingContext() as competing_context: with LoggingContext("competing"):
competing_context.request = "competing"
yield clock.sleep(0) yield clock.sleep(0)
self._check_test_key("competing") self._check_test_key("competing")
reactor.callLater(0, competing_callback) reactor.callLater(0, competing_callback)
with LoggingContext() as context_one: with LoggingContext("one"):
context_one.request = "one"
yield clock.sleep(0) yield clock.sleep(0)
self._check_test_key("one") self._check_test_key("one")
@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False] callback_completed = [False]
with LoggingContext() as context_one: with LoggingContext("one"):
context_one.request = "one"
# fire off function, but don't wait on it. # fire off function, but don't wait on it.
d2 = run_in_background(function) d2 = run_in_background(function)
@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context() sentinel_context = current_context()
with LoggingContext() as context_one: with LoggingContext("one"):
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function()) d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable # make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)
@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_make_deferred_yieldable_with_chained_deferreds(self): def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = current_context() sentinel_context = current_context()
with LoggingContext() as context_one: with LoggingContext("one"):
context_one.request = "one"
d1 = make_deferred_yieldable(_chained_deferred_function()) d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable # make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)
@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""Check that make_deferred_yieldable does the right thing when its """Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred""" argument isn't actually a deferred"""
with LoggingContext() as context_one: with LoggingContext("one"):
context_one.request = "one"
d1 = make_deferred_yieldable("bum") d1 = make_deferred_yieldable("bum")
self._check_test_key("one") self._check_test_key("one")
@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one") self._check_test_key("one")
def test_nested_logging_context(self): def test_nested_logging_context(self):
with LoggingContext(request="foo"): with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar") nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar") self.assertEqual(nested_context.name, "foo-bar")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self): def test_make_deferred_yieldable_with_await(self):
@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context() sentinel_context = current_context()
with LoggingContext() as context_one: with LoggingContext("one"):
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function()) d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable # make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)