Merge branch 'release-v1.13.0' into erikj/faster_device_lists_fetch

This commit is contained in:
Richard van der Hoff 2020-05-05 18:14:00 +01:00
commit 13dd458b8d
66 changed files with 847 additions and 691 deletions

View file

@ -30,23 +30,24 @@ recursive-include synapse/static *.gif
recursive-include synapse/static *.html
recursive-include synapse/static *.js
exclude Dockerfile
exclude .codecov.yml
exclude .coveragerc
exclude .dockerignore
exclude test_postgresql.sh
exclude .editorconfig
exclude Dockerfile
exclude mypy.ini
exclude sytest-blacklist
exclude test_postgresql.sh
include pyproject.toml
recursive-include changelog.d *
prune .buildkite
prune .circleci
prune .codecov.yml
prune .coveragerc
prune .github
prune contrib
prune debian
prune demo/etc
prune docker
prune mypy.ini
prune snap
prune stubs

View file

@ -75,6 +75,37 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.13.0
====================
Incorrect database migration in old synapse versions
----------------------------------------------------
A bug was introduced in Synapse 1.4.0 which could cause the room directory to
be incomplete or empty if Synapse was upgraded directly from v1.2.1 or earlier,
to versions between v1.4.0 and v1.12.x.
This will *not* be a problem for Synapse installations which were:
* created at v1.4.0 or later,
* upgraded via v1.3.x, or
* upgraded straight from v1.2.1 or earlier to v1.13.0 or later.
If completeness of the room directory is a concern, installations which are
affected can be repaired as follows:
1. Run the following sql from a `psql` or `sqlite3` console:
.. code:: sql
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_rooms', '{}', 'current_state_events_membership');
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
2. Restart synapse.
Upgrading to v1.12.0
====================

1
changelog.d/7172.misc Normal file
View file

@ -0,0 +1 @@
Use `stream.current_token()` and remove `stream_positions()`.

1
changelog.d/7363.misc Normal file
View file

@ -0,0 +1 @@
Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await.

1
changelog.d/7368.bugfix Normal file
View file

@ -0,0 +1 @@
Improve error responses when accessing remote public room lists.

1
changelog.d/7369.misc Normal file
View file

@ -0,0 +1 @@
Thread through instance name to replication client.

1
changelog.d/7387.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug which would cause the room durectory to be incorrectly populated if Synapse was upgraded directly from v1.2.1 or earlier to v1.4.0 or later. Note that this fix does not apply retrospectively; see the [upgrade notes](UPGRADE.rst#upgrading-to-v1130) for more information.

1
changelog.d/7393.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug in `EventContext.deserialize`.

1
changelog.d/7394.misc Normal file
View file

@ -0,0 +1 @@
Convert synapse.server_notices to async/await.

1
changelog.d/7395.misc Normal file
View file

@ -0,0 +1 @@
Convert synapse.notifier to async/await.

1
changelog.d/7401.feature Normal file
View file

@ -0,0 +1 @@
Add support for running replication over Redis when using workers.

1
changelog.d/7404.misc Normal file
View file

@ -0,0 +1 @@
Fix issues with the Python package manifest.

1
changelog.d/7408.misc Normal file
View file

@ -0,0 +1 @@
Clean up some LoggingContext code.

View file

@ -22,7 +22,10 @@ class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
class SubscriberProtocol:
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
def connectionMade(self): ...
def connectionLost(self, reason): ...
def lazyConnection(
host: str = ...,

View file

@ -537,8 +537,7 @@ class Auth(object):
return defer.succeed(auth_ids)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id: str, user: UserID):
async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
@ -547,17 +546,17 @@ class Auth(object):
user
"""
is_admin = yield self.is_server_admin(user)
is_admin = await self.is_server_admin(user)
if is_admin:
return True
user_id = user.to_string()
yield self.check_user_in_room(room_id, user_id)
await self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
power_level_event = yield self.state.get_current_state(
power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
)

View file

@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
@ -652,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
else:
self.send_handler = None
async def on_rdata(self, stream_name, token, rows):
await super(GenericWorkerReplicationHandler, self).on_rdata(
stream_name, token, rows
)
await self.process_and_notify(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
await self._process_and_notify(stream_name, instance_name, token, rows)
def get_streams_to_replicate(self):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
if self.send_handler:
args.update(self.send_handler.stream_positions())
return args
async def process_and_notify(self, stream_name, token, rows):
async def _process_and_notify(self, stream_name, instance_name, token, rows):
try:
if self.send_handler:
await self.send_handler.process_replication_rows(
@ -799,9 +784,6 @@ class FederationSenderHandler(object):
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)
def stream_positions(self):
return {"federation": self.federation_position}
async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.

View file

@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext):
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._prev_state_id and self._event_state_key is not None:
if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
key = (self._event_type, self._event_state_key)
self._prev_state_ids[key] = self._prev_state_id
if self._prev_state_id:
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = self._current_state_ids

View file

@ -883,18 +883,37 @@ class FederationClient(FederationBase):
def get_public_rooms(
self,
destination,
limit=None,
since_token=None,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
):
if destination == self.server_name:
return
"""Get the list of public rooms from a remote homeserver
Args:
remote_server: The name of the remote server
limit: Maximum amount of rooms to return
since_token: Used for result pagination
search_filter: A filter dictionary to send the remote homeserver
and filter the result set
include_all_networks: Whether to include results from all third party instances
third_party_instance_id: Whether to only include results from a specific third
party instance
Returns:
Deferred[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
Raises:
HttpResponseException: There was an exception returned from the remote server
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
requests over federation
"""
return self.transport_layer.get_public_rooms(
destination,
remote_server,
limit,
since_token,
search_filter,
@ -957,14 +976,13 @@ class FederationClient(FederationBase):
return signed_events
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
async def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None

View file

@ -15,13 +15,14 @@
# limitations under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
@ -326,18 +327,25 @@ class TransportLayerClient(object):
@log_function
def get_public_rooms(
self,
remote_server,
limit,
since_token,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
):
"""Get the list of public rooms from a remote homeserver
See synapse.federation.federation_client.FederationClient.get_public_rooms for
more information.
"""
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
data = {"include_all_networks": "true" if include_all_networks else "false"}
data = {
"include_all_networks": "true" if include_all_networks else "false"
} # type: Dict[str, Any]
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
@ -347,9 +355,19 @@ class TransportLayerClient(object):
data["filter"] = search_filter
response = yield self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True
)
try:
response = yield self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True
)
except HttpResponseException as e:
if e.code == 403:
raise SynapseError(
403,
"You are not allowed to view the public rooms list of %s"
% (remote_server,),
errcode=Codes.FORBIDDEN,
)
raise
else:
path = _create_v1_path("/publicRooms")
@ -363,9 +381,19 @@ class TransportLayerClient(object):
if since_token:
args["since"] = [since_token]
response = yield self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
try:
response = yield self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
except HttpResponseException as e:
if e.code == 403:
raise SynapseError(
403,
"You are not allowed to view the public rooms list of %s"
% (remote_server,),
errcode=Codes.FORBIDDEN,
)
raise
return response

View file

@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError()
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
"""Remove a user from the group; either a user is leaving or an admin
kicked them.
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_kick = False
if requester_user_id != user_id:
is_admin = yield self.store.is_user_admin_in_group(
is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_kick = True
yield self.store.remove_user_from_group(group_id, user_id)
await self.store.remove_user_from_group(group_id, user_id)
if is_kick:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {})
await groups_local.user_removed_from_group(group_id, user_id, {})
else:
yield self.transport_client.remove_user_from_group_notification(
await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
await self.store.maybe_delete_remote_profile_cache(user_id)
# Delete group if the last user has left
users = yield self.store.get_users_in_group(group_id, include_private=True)
users = await self.store.get_users_in_group(group_id, include_private=True)
if not users:
yield self.store.delete_group(group_id)
await self.store.delete_group(group_id)
return {}
@defer.inlineCallbacks
def create_group(self, group_id, requester_user_id, content):
group = yield self.check_group_is_ours(group_id, requester_user_id)
async def create_group(self, group_id, requester_user_id, content):
group = await self.check_group_is_ours(group_id, requester_user_id)
logger.info("Attempting to create group with ID: %r", group_id)
@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if group:
raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin(
is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
if not is_admin:
@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
long_description = profile.get("long_description")
user_profile = content.get("user_profile", {})
yield self.store.create_group(
await self.store.create_group(
group_id,
requester_user_id,
name=name,
@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if not self.hs.is_mine_id(requester_user_id):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
await self.attestations.verify_attestation(
remote_attestation, user_id=requester_user_id, group_id=group_id
)
@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
await self.store.add_user_to_group(
group_id,
requester_user_id,
is_admin=True,
@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
)
if not self.hs.is_mine_id(requester_user_id):
yield self.store.add_remote_profile_cache(
await self.store.add_remote_profile_cache(
requester_user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id}
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
async def delete_group(self, group_id, requester_user_id):
"""Deletes a group, kicking out all current members.
Only group admins or server admins can call this request
@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
Deferred
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups.
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin:
is_admin = yield self.auth.is_server_admin(
is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
users = yield self.store.get_users_in_group(group_id, include_private=True)
users = await self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks
def _kick_user_from_group(user_id):
async def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {})
await groups_local.user_removed_from_group(group_id, user_id, {})
else:
yield self.transport_client.remove_user_from_group_notification(
await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
yield self.store.maybe_delete_remote_profile_cache(user_id)
await self.store.maybe_delete_remote_profile_cache(user_id)
# We kick users out in the order of:
# 1. Non-admins
@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else:
non_admins.append(u["user_id"])
yield concurrently_execute(_kick_user_from_group, non_admins, 10)
yield concurrently_execute(_kick_user_from_group, admins, 10)
yield _kick_user_from_group(requester_user_id)
await concurrently_execute(_kick_user_from_group, non_admins, 10)
await concurrently_execute(_kick_user_from_group, admins, 10)
await _kick_user_from_group(requester_user_id)
yield self.store.delete_group(group_id)
await self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content):

View file

@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, context=None):
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids()
current_state = yield self.store.get_events(
current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events(
list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
current_state = await self.state_handler.get_current_state(
event.room_id
)
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
await self.kick_guest_users(current_state)
@defer.inlineCallbacks
def kick_guest_users(self, current_state):
async def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
await handler.update_membership(
requester,
target_user,
member_event.room_id,

View file

@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
room_alias, room_id, servers, creator=creator
)
@defer.inlineCallbacks
def create_association(
async def create_association(
self,
requester: Requester,
room_alias: RoomAlias,
@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
else:
# Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room).
is_admin = yield self.auth.is_server_admin(requester.user)
is_admin = await self.auth.is_server_admin(requester.user)
if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
403, "You must be in the room to create an alias for it"
@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
can_create = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
yield self._create_association(room_alias, room_id, servers, creator=user_id)
await self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks
def delete_association(self, requester: Requester, room_alias: RoomAlias):
async def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string()
try:
can_delete = yield self._user_can_delete_alias(room_alias, user_id)
can_delete = await self._user_can_delete_alias(room_alias, user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown room alias")
@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
room_id = yield self._delete_association(room_alias)
room_id = await self._delete_association(room_alias)
try:
yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
@defer.inlineCallbacks
def _update_canonical_alias(
async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
"""
alias_event = yield self.state.get_current_state(
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"]
if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock
return defer.succeed(True)
@defer.inlineCallbacks
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
One of the following must be true:
@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
for the current room.
"""
creator = yield self.store.get_room_alias_creator(alias.to_string())
creator = await self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
return True
# Resolve the alias to the corresponding room.
room_mapping = yield self.get_association(alias)
room_mapping = await self.get_association(alias)
room_id = room_mapping["room_id"]
if not room_id:
return False
res = yield self.auth.check_can_change_room_list(
res = await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id)
)
return res
@defer.inlineCallbacks
def edit_published_room_list(
async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
):
"""Edit the entry of the room in the published room list.
@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to publish rooms to the room list"
)
room = yield self.store.get_room(room_id)
room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
can_change_room_list = yield self.auth.check_can_change_room_list(
can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user
)
if not can_change_room_list:
@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
making_public = visibility == "public"
if making_public:
room_aliases = yield self.store.get_aliases_for_room(room_id)
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
room_aliases = await self.store.get_aliases_for_room(room_id)
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
if canonical_alias:
room_aliases.append(canonical_alias)
@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public)
await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(

View file

@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals],
}
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(
async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
third_party_invite = {"signed": signed}
@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
room_version = yield self.store.get_room_version_id(room_id)
if await self.auth.check_host_in_room(room_id, self.hs.hostname):
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = yield self.add_display_name_to_third_party_invite(
event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
yield self.auth.check_from_context(room_version, event, context)
await self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite(
await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict
)

View file

@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
async def create_group(self, group_id, user_id, content):
"""Create a group
"""
logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.create_group(
res = await self.groups_server_handler.create_group(
group_id, user_id, content
)
local_attestation = None
@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
content["user_profile"] = await self.profile_handler.get_profile(user_id)
try:
res = yield self.transport_client.create_group(
res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
)
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
"""Remove a user from a group
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
res = await self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
try:
res = yield self.transport_client.remove_user_from_group(
res = await self.transport_client.remove_user_from_group(
get_domain_from_id(group_id),
group_id,
requester_user_id,

View file

@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
@ -656,7 +655,7 @@ class EventCreationHandler(object):
)
return prev_state
yield self.handle_new_client_event(
await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
with (await self.limiter.queue(event_dict["room_id"])):
event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
@ -794,9 +791,9 @@ class EventCreationHandler(object):
):
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
room_version = yield self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -805,7 +802,7 @@ class EventCreationHandler(object):
)
try:
yield self.auth.check_from_context(room_version, event, context)
await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
yield self.action_generator.handle_push_actions_for_event(event, context)
await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@ -826,7 +823,7 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield self.send_event_to_master(
await self.send_event_to_master(
event_id=event.event_id,
store=self.store,
requester=requester,
@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True
return
yield self.persist_and_notify_client_event(
await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
@ -883,8 +880,7 @@ class EventCreationHandler(object):
Codes.BAD_ALIAS,
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
@ -901,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@ -913,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender
)
yield self.base_handler.ratelimit(
await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
yield self.base_handler.maybe_kick_guest_users(event, context)
await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases.
@ -927,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = yield self.store.get_event(original_event_id)
original_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
@ -937,7 +933,7 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
yield self._validate_canonical_alias(
await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
@ -957,7 +953,7 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
yield self._validate_canonical_alias(
await self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
@ -969,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids()
current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [
e_id
@ -978,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender)
]
state_to_include = yield self.store.get_events(state_to_include_ids)
state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
@ -996,8 +992,8 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield defer.ensureDeferred(
federation_handler.send_invite(invitee.domain, event)
returned_invite = await federation_handler.send_invite(
invitee.domain, event
)
event.unsigned.pop("room_state", None)
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction(
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context
)
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
except Exception:
logger.exception("Error bumping presence active time")
@defer.inlineCallbacks
def _send_dummy_events_to_fill_extremities(self):
async def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
self._expire_rooms_to_exclude_from_dummy_event_insertion()
room_ids = yield self.store.get_rooms_with_many_extremities(
room_ids = await self.store.get_rooms_with_many_extremities(
min_count=10,
limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send
# the dummy event with.
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = yield self.state.get_current_users_in_room(
members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
dummy_event_sent = False
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
continue
requester = create_requester(user_id)
try:
event, context = yield self.create_event(
event, context = await self.create_event(
requester,
{
"type": "org.matrix.dummy_event",
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
yield self.send_nonmember_event(
await self.send_nonmember_event(
requester, event, context, ratelimit=False
)
dummy_event_sent = True

View file

@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"]
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False
):
"""Set the displayname of a user
Args:
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
400,
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"]
@defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response
@defer.inlineCallbacks
def _update_join_states(self, requester, target_user):
async def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
await self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
try:
# Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data.
yield handler.update_membership(
await handler.update_membership(
requester,
target_user,
room_id,

View file

@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
"""Registers a new client on the server.
Args:
localpart : The local part of the user ID to register. If None,
localpart: The local part of the user ID to register. If None,
one will be generated.
password (unicode) : The password to assign to this user so they can
password (unicode): The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from
@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
async def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
is_real_user = yield self.store.is_real_user(user_id)
is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
count = yield self.store.count_real_users()
count = await self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency
# loop
yield self.hs.get_room_creation_handler().create_room(
await self.hs.get_room_creation_handler().create_room(
fake_requester,
config={
"preset": "public_chat",
@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
else:
yield self._join_user_to_room(fake_requester, r)
await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
@defer.inlineCallbacks
def post_consent_actions(self, user_id):
async def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
"""
yield self._auto_join_rooms(user_id)
await self._auto_join_rooms(user_id)
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1
return str(id)
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
async def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias
)
room_id = room_id.to_string()
@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
yield room_member_handler.update_membership(
await room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
return (device_id, access_token)
@defer.inlineCallbacks
def post_registration_actions(self, user_id, auth_result, access_token):
async def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration
Args:
@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
yield self._post_registration_client(
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
return
@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
yield self.store.upsert_monthly_active_user(user_id)
await self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(user_id, threepid, access_token)
await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(user_id, threepid)
await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
async def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration
Args:
@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
yield self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token):

View file

@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
return ret
@defer.inlineCallbacks
def _upgrade_room(
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
user_id = requester.user.to_string()
# start by allocating a new room id
r = yield self.store.get_room(old_room_id)
r = await self.store.get_room(old_room_id)
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = yield self._generate_room_id(
new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
)
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
(
tombstone_event,
tombstone_context,
) = yield self.event_creation_handler.create_event(
) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Tombstone,
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
},
token_id=requester.access_token_id,
)
old_room_version = yield self.store.get_room_version_id(old_room_id)
yield self.auth.check_from_context(
old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
yield self.clone_existing_room(
await self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
yield self.event_creation_handler.send_nonmember_event(
await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context
)
old_room_state = yield tombstone_context.get_current_state_ids()
old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases
yield self._move_aliases_to_new_room(
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
)
# Copy over user push rules, tags and migrate room directory state
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id
)
# finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
)
return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
async def _update_upgraded_room_pls(
self,
requester: Requester,
old_room_id: str,
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
)
return
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
async def clone_existing_room(
self,
requester: Requester,
old_room_id: str,
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable
# Get old room's create event
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True):
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id)
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level
yield self._send_events_for_new_room(
await self._send_events_for_new_room(
requester,
new_room_id,
# we expect to override all the presets with initial_state, so this is
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
)
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content
and old_event.content["membership"] == "ban"
):
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
new_room_id,
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins
# XXX 3pid invites
@defer.inlineCallbacks
def _move_aliases_to_new_room(
async def _move_aliases_to_new_room(
self,
requester: Requester,
old_room_id: str,
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id)
aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
yield directory_handler.delete_association(requester, alias)
await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
await directory_handler.create_association(
requester,
RoomAlias.from_string(alias),
new_room_id,
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information.
try:
if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
):
""" Creates a new room.
Args:
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
yield self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(user_id)
if (
self._server_notices_mxid is not None
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
event_allowed = yield self.third_party_event_rules.on_create_room(
event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
if not event_allowed:
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
yield self.ratelimit(requester)
await self.ratelimit(requester)
room_version_id = config.get(
"room_version", self.config.default_room_version.identifier
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = yield self.store.get_association_from_room_alias(room_alias)
mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
if (
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
room_id = yield self._generate_room_id(
room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version,
)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
yield directory_handler.create_association(
await directory_handler.create_association(
requester=requester,
room_id=room_id,
room_alias=room_alias,
@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
yield self._send_events_for_new_room(
await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite(
await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler):
return result
@defer.inlineCallbacks
def _send_events_for_new_room(
async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler):
return e
@defer.inlineCallbacks
def send(etype, content, **kwargs):
async def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
yield send(etype=EventTypes.Create, content=creation_content)
await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
yield send(etype=EventTypes.PowerLevels, content=pl_content)
await send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
"users": {creator_id: 100},
@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
yield send(etype=EventTypes.PowerLevels, content=power_level_content)
await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send(
await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
yield send(
await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send(
await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
yield send(
await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content)
await send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks
def _generate_room_id(

View file

@ -142,8 +142,7 @@ class RoomMemberHandler(object):
"""
raise NotImplementedError()
@defer.inlineCallbacks
def _local_membership_update(
async def _local_membership_update(
self,
requester,
target,
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
event, context = yield self.event_creation_handler.create_event(
event, context = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event(
duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
yield self.event_creation_handler.handle_new_client_event(
await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@ -203,15 +202,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id)
await self._user_left_room(target, room_id)
return event
@ -253,8 +252,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks
def update_membership(
async def update_membership(
self,
requester,
target,
@ -269,8 +267,8 @@ class RoomMemberHandler(object):
):
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
with (await self.member_linearizer.queue(key)):
result = await self._update_membership(
requester,
target,
room_id,
@ -285,8 +283,7 @@ class RoomMemberHandler(object):
return result
@defer.inlineCallbacks
def _update_membership(
async def _update_membership(
self,
requester,
target,
@ -321,7 +318,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
yield self.federation_handler.exchange_third_party_invite(
await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@ -332,7 +329,7 @@ class RoomMemberHandler(object):
remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id)
is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@ -351,7 +348,7 @@ class RoomMemberHandler(object):
is_requester_admin = True
else:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@ -370,9 +367,9 @@ class RoomMemberHandler(object):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
)
@ -381,7 +378,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
@ -413,7 +410,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
is_blocked = await self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
@ -424,18 +421,18 @@ class RoomMemberHandler(object):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids)
is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(current_state_ids)
guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self._get_inviter(target.to_string(), room_id)
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@ -443,13 +440,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
content["displayname"] = await profile.get_displayname(target)
content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
remote_join_response = yield self._remote_join(
remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
@ -458,7 +455,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self._get_inviter(target.to_string(), room_id)
inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@ -472,12 +469,12 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
return res
res = yield self._local_membership_update(
res = await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@ -572,8 +569,7 @@ class RoomMemberHandler(object):
)
continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
async def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@ -599,27 +595,27 @@ class RoomMemberHandler(object):
else:
requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if prev_event is not None:
return
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids)
guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN):
is_blocked = yield self.store.is_room_blocked(room_id)
is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event(
await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
@ -633,15 +629,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target_user, room_id)
await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target_user, room_id)
await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@ -699,8 +695,7 @@ class RoomMemberHandler(object):
if invite:
return UserID.from_string(invite.sender)
@defer.inlineCallbacks
def do_3pid_invite(
async def do_3pid_invite(
self,
room_id,
inviter,
@ -712,7 +707,7 @@ class RoomMemberHandler(object):
id_access_token=None,
):
if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@ -720,9 +715,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester)
await self.base_handler.ratelimit(requester)
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
)
if not can_invite:
@ -737,16 +732,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server"
)
invitee = yield self.identity_handler.lookup_3pid(
invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee:
yield self.update_membership(
await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
yield self._make_and_store_3pid_invite(
await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@ -757,8 +752,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
async def _make_and_store_3pid_invite(
self,
requester,
id_server,
@ -769,7 +763,7 @@ class RoomMemberHandler(object):
txn_id,
id_access_token=None,
):
room_state = yield self.state_handler.get_current_state(room_id)
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@ -807,7 +801,7 @@ class RoomMemberHandler(object):
public_keys,
fallback_public_key,
display_name,
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@ -823,7 +817,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity
too_complex = yield self._is_remote_room_too_complex(
too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts
)
if too_complex is True:
@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield defer.ensureDeferred(
self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
yield self._user_joined_room(user, room_id)
await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return
# Check again, but with the local state events
too_complex = yield self._is_local_room_too_complex(room_id)
too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
yield self.update_membership(
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
raise SynapseError(
@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
return user_joined_room(self.distributor, target, room_id)
return defer.succeed(user_joined_room(self.distributor, target, room_id))
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return user_left_room(self.distributor, target, room_id)
return defer.succeed(user_left_room(self.distributor, target, room_id))
@defer.inlineCallbacks
def forget(self, user, room_id):

View file

@ -27,6 +27,7 @@ import inspect
import logging
import threading
import types
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal
@ -287,6 +288,46 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
This exists for backwards compatibility. ``current_context()`` should be
called directly.
Returns:
LoggingContext: the current logging context
"""
warnings.warn(
"synapse.logging.context.LoggingContext.current_context() is deprecated "
"in favor of synapse.logging.context.current_context().",
DeprecationWarning,
stacklevel=2,
)
return current_context()
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
This exists for backwards compatibility. ``set_current_context()`` should be
called directly.
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
warnings.warn(
"synapse.logging.context.LoggingContext.set_current_context() is deprecated "
"in favor of synapse.logging.context.set_current_context().",
DeprecationWarning,
stacklevel=2,
)
return set_current_context(context)
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = set_current_context(self)

View file

@ -273,10 +273,9 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
@defer.inlineCallbacks
def _notify_app_services(self, room_stream_id):
async def _notify_app_services(self, room_stream_id):
try:
yield self.appservice_handler.notify_interested_services(room_stream_id)
await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
@ -475,20 +474,18 @@ class Notifier(object):
return result
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
async def _get_room_ids(self, user, explicit_room_id):
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
return [explicit_room_id], True
if (yield self._is_world_readable(explicit_room_id)):
if await self._is_world_readable(explicit_room_id):
return [explicit_room_id], False
raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state(
async def _is_world_readable(self, room_id):
state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:

View file

@ -16,6 +16,7 @@
import abc
import logging
import re
from inspect import signature
from typing import Dict, List, Tuple
from six import raise_from
@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
Requests are sent to master process by default, but can be sent to other
named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
), "`instance_name` is a reserved paramater name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
), "`instance_name` is a reserved paramater name"
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
def send_request(instance_name="master", **kwargs):
# Currently we only support sending requests to master process.
if instance_name != "master":
raise Exception("Unknown instance")
data = yield cls._serialize_payload(**kwargs)
url_args = [

View file

@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)
self._instance_name = hs.get_instance_name()
# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
upto_token = parse_integer(request, "upto_token", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token
self._instance_name, from_token, upto_token
)
return (

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, Optional
from typing import Optional
import six
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self.hs = hs
def stream_positions(self) -> Dict[str, int]:
"""
Get the current positions of all the streams this store wants to subscribe to
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()

View file

@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)

View file

@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)

View file

@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
# The user signature stream uses the same stream ID generator as the
# device list stream, so set them both to the device list ID
# generator's current token.
current_token = self._device_list_id_gen.get_current_token()
result[DeviceListsStream.NAME] = current_token
result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)

View file

@ -93,12 +93,6 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)

View file

@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)

View file

@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)

View file

@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)

View file

@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()

View file

@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))

View file

@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)

View file

@ -16,7 +16,7 @@
"""
import logging
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from twisted.internet.protocol import ReconnectingClientFactory
@ -86,37 +86,22 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore):
self.store = store
async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])

View file

@ -278,19 +278,24 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
@ -314,15 +319,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(cmd.stream_name, [])
# Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s",
cmd.stream_name,
)
return
current_token = stream.current_token()
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@ -333,7 +330,9 @@ class ReplicationCommandHandler:
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
# TODO: add some tests for this
@ -342,7 +341,10 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
cmd.stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.

View file

@ -61,6 +61,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
outbound_redis_connection = None # type: txredisapi.RedisProtocol
def connectionMade(self):
super().connectionMade()
logger.info("Connected to redis instance")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())
@ -119,6 +120,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason):
super().connectionLost(reason)
logger.info("Lost connection to redis instance")
self.handler.lost_connection(self)
@ -189,5 +191,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
p.password = self.password
return p

View file

@ -16,7 +16,7 @@
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
#
# The arguments are:
#
# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
@ -93,6 +94,7 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
@ -108,9 +110,11 @@ class Stream(object):
stream tokens. See the UpdateFunction type definition for more info.
Args:
local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function
@ -135,14 +139,14 @@ class Stream(object):
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token
return updates, current_token, limited
async def get_updates_since(
self, from_token: Token, upto_token: Token
self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@ -160,19 +164,19 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]
@ -226,6 +233,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@ -261,7 +269,9 @@ class PresenceStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(store.get_current_presence_token, update_function)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
)
class TypingStream(Stream):
@ -284,7 +294,9 @@ class TypingStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(typing_handler.get_current_token, update_function)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
)
class ReceiptsStream(Stream):
@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
@ -322,14 +335,16 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
self._current_token, self._update_function
hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
async def _update_function(
self, instance_name: str, from_token: Token, to_token: Token, limit: int
):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@ -356,6 +371,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@ -387,6 +403,7 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
@ -487,6 +508,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
@ -517,6 +539,7 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes

View file

@ -118,11 +118,17 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
self._store.get_current_events_token, self._update_function,
hs.get_instance_name(),
self._store.get_current_events_token,
self._update_function,
)
async def _update_function(
self, from_token: Token, current_token: Token, target_row_count: int
self,
instance_name: str,
from_token: Token,
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult:
# the events stream merges together three separate sources:

View file

@ -48,8 +48,8 @@ class FederationStream(Stream):
current_token = lambda: 0
update_function = self._stub_update_function
super().__init__(current_token, update_function)
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
async def _stub_update_function(from_token, upto_token, limit):
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False

View file

@ -16,8 +16,6 @@ import logging
from six import iteritems, string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config)
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so
Args:
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
return
self._users_in_progress.add(user_id)
try:
u = yield self._store.get_user_by_id(user_id)
u = await self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests:
# don't send to guests
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
content = copy_with_str_subst(
self._server_notice_content, {"consent_uri": consent_uri}
)
yield self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent(
await self._server_notices_manager.send_notice(user_id, content)
await self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version
)
except SynapseError as e:

View file

@ -16,8 +16,6 @@ import logging
from six import iteritems
from twisted.internet import defer
from synapse.api.constants import (
EventTypes,
LimitBlockingTypes,
@ -50,8 +48,7 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier()
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in
two cases.
1. The server has reached its limit does not reflect this
@ -74,13 +71,13 @@ class ResourceLimitsServerNotices(object):
# Don't try and send server notices unless they've been enabled
return
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
timestamp = await self._store.user_last_seen_monthly_active(user_id)
if timestamp is None:
# This user will be blocked from receiving the notice anyway.
# In practice, not sure we can ever get here
return
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user(
room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
user_id
)
@ -88,10 +85,10 @@ class ResourceLimitsServerNotices(object):
logger.warning("Failed to get server notices room")
return
yield self._check_and_set_tags(user_id, room_id)
await self._check_and_set_tags(user_id, room_id)
# Determine current state of room
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
limit_msg = None
limit_type = None
@ -99,7 +96,7 @@ class ResourceLimitsServerNotices(object):
# Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen
# to other users if they were to arrive.
yield self._auth.check_auth_blocking()
await self._auth.check_auth_blocking()
except ResourceLimitError as e:
limit_msg = e.msg
limit_type = e.limit_type
@ -112,22 +109,21 @@ class ResourceLimitsServerNotices(object):
# We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return
if currently_blocked:
self._remove_limit_block_notification(user_id, ref_events)
await self._remove_limit_block_notification(user_id, ref_events)
return
if currently_blocked and not limit_msg:
# Room is notifying of a block, when it ought not to be.
yield self._remove_limit_block_notification(user_id, ref_events)
await self._remove_limit_block_notification(user_id, ref_events)
elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be.
yield self._apply_limit_block_notification(
await self._apply_limit_block_notification(
user_id, limit_msg, limit_type
)
except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e)
@defer.inlineCallbacks
def _remove_limit_block_notification(self, user_id, ref_events):
async def _remove_limit_block_notification(self, user_id, ref_events):
"""Utility method to remove limit block notifications from the server
notices room.
@ -137,12 +133,13 @@ class ResourceLimitsServerNotices(object):
limit blocking and need to be preserved.
"""
content = {"pinned": ref_events}
yield self._server_notices_manager.send_notice(
await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
@defer.inlineCallbacks
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
async def _apply_limit_block_notification(
self, user_id, event_body, event_limit_type
):
"""Utility method to apply limit block notifications in the server
notices room.
@ -159,17 +156,16 @@ class ResourceLimitsServerNotices(object):
"admin_contact": self._config.admin_contact,
"limit_type": event_limit_type,
}
event = yield self._server_notices_manager.send_notice(
event = await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Message
)
content = {"pinned": [event.event_id]}
yield self._server_notices_manager.send_notice(
await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
@defer.inlineCallbacks
def _check_and_set_tags(self, user_id, room_id):
async def _check_and_set_tags(self, user_id, room_id):
"""
Since server notices rooms were originally not with tags,
important to check that tags have been set correctly
@ -177,20 +173,19 @@ class ResourceLimitsServerNotices(object):
user_id(str): the user in question
room_id(str): the server notices room for that user
"""
tags = yield self._store.get_tags_for_room(user_id, room_id)
tags = await self._store.get_tags_for_room(user_id, room_id)
need_to_set_tag = True
if tags:
if SERVER_NOTICE_ROOM_TAG in tags:
# tag already present, nothing to do here
need_to_set_tag = False
if need_to_set_tag:
max_id = yield self._store.add_tag_to_room(
max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@defer.inlineCallbacks
def _is_room_currently_blocked(self, room_id):
async def _is_room_currently_blocked(self, room_id):
"""
Determines if the room is currently blocked
@ -198,7 +193,7 @@ class ResourceLimitsServerNotices(object):
room_id(str): The room id of the server notices room
Returns:
Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking
This list can be used as a convenience in the case where the block
@ -208,7 +203,7 @@ class ResourceLimitsServerNotices(object):
currently_blocked = False
pinned_state_event = None
try:
pinned_state_event = yield self._state.get_current_state(
pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned
)
except AuthError:
@ -219,7 +214,7 @@ class ResourceLimitsServerNotices(object):
if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", []))
events = yield self._store.get_events(referenced_events)
events = await self._store.get_events(referenced_events)
for event_id, event in iteritems(events):
if event.type != EventTypes.Message:
continue

View file

@ -14,11 +14,9 @@
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
"""
return self._config.server_notices_mxid is not None
@defer.inlineCallbacks
def send_notice(
async def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None
):
"""Send a notice to the given user
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
Returns:
Deferred[FrozenEvent]
"""
room_id = yield self.get_or_create_notice_room_for_user(user_id)
yield self.maybe_invite_user_to_room(user_id, room_id)
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid)
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event(
res = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
return res
@cachedInlineCallbacks()
def get_or_create_notice_room_for_user(self, user_id):
@cached()
async def get_or_create_notice_room_for_user(self, user_id):
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in rooms:
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
# be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it
# the server notices room.
user_ids = yield self._store.get_users_in_room(room.room_id)
user_ids = await self._store.get_users_in_room(room.room_id)
if self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice
# user
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
info = yield self._room_creation_handler.create_room(
info = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
)
room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room(
max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id
@defer.inlineCallbacks
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
"""Invite the given user to the given server room, unless the user has already
joined or been invited to it.
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in joined_rooms:
if room.room_id == room_id:
return
yield self._room_member_handler.update_membership(
await self._room_member_handler.update_membership(
requester=requester,
target=UserID.from_string(user_id),
room_id=room_id,

View file

@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
ResourceLimitsServerNotices(hs),
)
@defer.inlineCallbacks
def on_user_syncing(self, user_id):
async def on_user_syncing(self, user_id):
"""Called when the user performs a sync operation.
Args:
user_id (str): mxid of user who synced
"""
for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id)
await sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks
def on_user_ip(self, user_id):
async def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request.
Args:
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing
for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id)
await sn.maybe_send_server_notice_to_user(user_id)

View file

@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
@defer.inlineCallbacks
def is_server_admin(self, user):
async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
Args:
@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
res = yield self.db.simple_select_one_onecol(
res = await self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",

View file

@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN (
'populate_stats_cleanup'
);
-- this relies on current_state_events.membership having been populated, so add
-- a dependency on current_state_events_membership.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_rooms', '{}', '');
('populate_stats_process_rooms', '{}', 'current_state_events_membership');
-- this also relies on current_state_events.membership having been populated, but
-- we get that as a side-effect of depending on populate_stats_process_rooms.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');

View file

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
from tests.test_utils.event_injection import create_event
class TestEventContext(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.room_id = self.helper.create_room_as(tok=self.user_tok)
def test_serialize_deserialize_msg(self):
"""Test that an EventContext for a message event is the same after
serialize/deserialize.
"""
event, context = create_event(
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
)
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_no_prev(self):
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
event, context = create_event(
self.hs,
room_id=self.room_id,
type="m.test",
sender=self.user_id,
state_key="",
)
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_prev(self):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
event, context = create_event(
self.hs,
room_id=self.room_id,
type="m.room.member",
sender=self.user_id,
state_key=self.user_id,
content={"membership": "leave"},
)
self._check_serialize_deserialize(event, context)
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self.storage, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
self.assertEqual(
context.state_group_before_event, d_context.state_group_before_event
)
self.assertEqual(context.prev_group, d_context.prev_group)
self.assertEqual(context.delta_ids, d_context.delta_ids)
self.assertEqual(context.app_service, d_context.app_service)
self.assertEqual(
self.get_success(context.get_current_state_ids()),
self.get_success(d_context.get_current_state_ids()),
)
self.assertEqual(
self.get_success(context.get_prev_state_ids()),
self.get_success(d_context.get_prev_state_ids()),
)

View file

@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank Jr.",
)
# Set displayname again
yield self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
)
self.assertEquals(
@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
)
# Setting displayname a second time is forbidden
d = self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)
yield self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
)
)
yield self.assertFailure(d, AuthError)
@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
)
self.assertEquals(
@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
)
# Set avatar again
yield self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/me.png",
yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/me.png",
)
)
self.assertEquals(
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
)
# Set avatar a second time is forbidden
d = self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
d = defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
)
yield self.assertFailure(d, SynapseError)

View file

@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=False)
self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=1)
self.store.is_real_user = Mock(return_value=True)
self.store.count_real_users = Mock(return_value=defer.succeed(1))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=2)
self.store.is_real_user = Mock(return_value=True)
self.store.count_real_users = Mock(return_value=defer.succeed(2))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
yield self.hs.get_auth().check_auth_blocking()
await self.hs.get_auth().check_auth_blocking()
need_register = True
try:
yield self.handler.check_username(localpart)
await self.handler.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id)
if need_register:
yield self.handler.register_with_store(
await self.handler.register_with_store(
user_id=user_id,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
yield defer.ensureDeferred(
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
)
await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(
await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.hs.get_profile_handler().set_displayname(
await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True
)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
import attr
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
return TestReplicationDataHandler(self.worker_hs)
def reconnect(self):
if self._client_transport:
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler):
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self, store: BaseSlavedStore):
super().__init__(store)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
def __init__(self, hs: HomeServer):
super().__init__(hs)
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self):
return self.stream_positions
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s()
class OneShotRequestFactory:

View file

@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.user_tok = self.login("u1", "pass")
self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
# now we should have received all the expected rows in the right order.
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
#
# we expect:
#
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# of the states that got reverted.
# - two rows for state2
received_rows = self.test_handler.received_rdata_rows
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
# first check the first two rows, which should be state1
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for

View file

@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# there should be one RDATA command
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))

View file

@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.reconnect()
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0)
@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]

View file

@ -55,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self._rlsn._server_notices_manager.send_notice = Mock()
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
# self.server_notices_mxid = "@server:test"
# self.server_notices_mxid_display_name = None
# self.server_notices_mxid_avatar_url = None
# self.server_notices_room_name = "Server Notices"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
returnValue=""
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock()
self._rlsn._store.get_tags_for_room = Mock(return_value={})
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
self.hs.config.admin_contact = "mailto:user@test.com"
def test_maybe_send_server_notice_to_user_flag_off(self):
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once()
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo")
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
mock_event = Mock(
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo")
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
)
),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self.assertTrue(self._send_notice.call_count == 0)
self.assertEqual(self._send_notice.call_count, 0)
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
)
),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
)
),
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, []))
)
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
# Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

View file

@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room = room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
room = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]

View file

@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
@ -75,6 +76,23 @@ def inject_event(
"""
test_reactor = hs.get_reactor()
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event
def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
test_reactor = hs.get_reactor()
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
@ -89,8 +107,4 @@ def inject_event(
test_reactor.advance(0)
event, context = get_awaitable_result(d)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event
return event, context