async/await is_server_admin (#7363)

This commit is contained in:
Andrew Morgan 2020-05-01 15:15:36 +01:00 committed by GitHub
parent 2e8955f4a6
commit 6b22921b19
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 410 additions and 438 deletions

1
changelog.d/7363.misc Normal file
View file

@ -0,0 +1 @@
Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await.

View file

@ -537,8 +537,7 @@ class Auth(object):
return defer.succeed(auth_ids) return defer.succeed(auth_ids)
@defer.inlineCallbacks async def check_can_change_room_list(self, room_id: str, user: UserID):
def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the """Determine whether the user is allowed to edit the room's entry in the
published room list. published room list.
@ -547,17 +546,17 @@ class Auth(object):
user user
""" """
is_admin = yield self.is_server_admin(user) is_admin = await self.is_server_admin(user)
if is_admin: if is_admin:
return True return True
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_in_room(room_id, user_id) await self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this # We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the # by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events # m.room.canonical_alias events
power_level_event = yield self.state.get_current_state( power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, "" room_id, EventTypes.PowerLevels, ""
) )

View file

@ -976,14 +976,13 @@ class FederationClient(FederationBase):
return signed_events return signed_events
@defer.inlineCallbacks async def forward_third_party_invite(self, destinations, room_id, event_dict):
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
try: try:
yield self.transport_layer.exchange_third_party_invite( await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict destination=destination, room_id=room_id, event_dict=event_dict
) )
return None return None

View file

@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError() raise NotImplementedError()
@defer.inlineCallbacks async def remove_user_from_group(
def remove_user_from_group(self, group_id, user_id, requester_user_id, content): self, group_id, user_id, requester_user_id, content
):
"""Remove a user from the group; either a user is leaving or an admin """Remove a user from the group; either a user is leaving or an admin
kicked them. kicked them.
""" """
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_kick = False is_kick = False
if requester_user_id != user_id: if requester_user_id != user_id:
is_admin = yield self.store.is_user_admin_in_group( is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
if not is_admin: if not is_admin:
@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_kick = True is_kick = True
yield self.store.remove_user_from_group(group_id, user_id) await self.store.remove_user_from_group(group_id, user_id)
if is_kick: if is_kick:
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {}) await groups_local.user_removed_from_group(group_id, user_id, {})
else: else:
yield self.transport_client.remove_user_from_group_notification( await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {} get_domain_from_id(user_id), group_id, user_id, {}
) )
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id) await self.store.maybe_delete_remote_profile_cache(user_id)
# Delete group if the last user has left # Delete group if the last user has left
users = yield self.store.get_users_in_group(group_id, include_private=True) users = await self.store.get_users_in_group(group_id, include_private=True)
if not users: if not users:
yield self.store.delete_group(group_id) await self.store.delete_group(group_id)
return {} return {}
@defer.inlineCallbacks async def create_group(self, group_id, requester_user_id, content):
def create_group(self, group_id, requester_user_id, content): group = await self.check_group_is_ours(group_id, requester_user_id)
group = yield self.check_group_is_ours(group_id, requester_user_id)
logger.info("Attempting to create group with ID: %r", group_id) logger.info("Attempting to create group with ID: %r", group_id)
@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if group: if group:
raise SynapseError(400, "Group already exists") raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin( is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id) UserID.from_string(requester_user_id)
) )
if not is_admin: if not is_admin:
@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
long_description = profile.get("long_description") long_description = profile.get("long_description")
user_profile = content.get("user_profile", {}) user_profile = content.get("user_profile", {})
yield self.store.create_group( await self.store.create_group(
group_id, group_id,
requester_user_id, requester_user_id,
name=name, name=name,
@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if not self.hs.is_mine_id(requester_user_id): if not self.hs.is_mine_id(requester_user_id):
remote_attestation = content["attestation"] remote_attestation = content["attestation"]
yield self.attestations.verify_attestation( await self.attestations.verify_attestation(
remote_attestation, user_id=requester_user_id, group_id=group_id remote_attestation, user_id=requester_user_id, group_id=group_id
) )
@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = None local_attestation = None
remote_attestation = None remote_attestation = None
yield self.store.add_user_to_group( await self.store.add_user_to_group(
group_id, group_id,
requester_user_id, requester_user_id,
is_admin=True, is_admin=True,
@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
) )
if not self.hs.is_mine_id(requester_user_id): if not self.hs.is_mine_id(requester_user_id):
yield self.store.add_remote_profile_cache( await self.store.add_remote_profile_cache(
requester_user_id, requester_user_id,
displayname=user_profile.get("displayname"), displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"), avatar_url=user_profile.get("avatar_url"),
@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id} return {"group_id": group_id}
@defer.inlineCallbacks async def delete_group(self, group_id, requester_user_id):
def delete_group(self, group_id, requester_user_id):
"""Deletes a group, kicking out all current members. """Deletes a group, kicking out all current members.
Only group admins or server admins can call this request Only group admins or server admins can call this request
@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
Deferred Deferred
""" """
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups. # Only server admins or group admins can delete groups.
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin: if not is_admin:
is_admin = yield self.auth.is_server_admin( is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id) UserID.from_string(requester_user_id)
) )
@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise SynapseError(403, "User is not an admin") raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it # Before deleting the group lets kick everyone out of it
users = yield self.store.get_users_in_group(group_id, include_private=True) users = await self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks async def _kick_user_from_group(user_id):
def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {}) await groups_local.user_removed_from_group(group_id, user_id, {})
else: else:
yield self.transport_client.remove_user_from_group_notification( await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {} get_domain_from_id(user_id), group_id, user_id, {}
) )
yield self.store.maybe_delete_remote_profile_cache(user_id) await self.store.maybe_delete_remote_profile_cache(user_id)
# We kick users out in the order of: # We kick users out in the order of:
# 1. Non-admins # 1. Non-admins
@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else: else:
non_admins.append(u["user_id"]) non_admins.append(u["user_id"])
yield concurrently_execute(_kick_user_from_group, non_admins, 10) await concurrently_execute(_kick_user_from_group, non_admins, 10)
yield concurrently_execute(_kick_user_from_group, admins, 10) await concurrently_execute(_kick_user_from_group, admins, 10)
yield _kick_user_from_group(requester_user_id) await _kick_user_from_group(requester_user_id)
yield self.store.delete_group(group_id) await self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content): def _parse_join_policy_from_contents(content):

View file

@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)) retry_after_ms=int(1000 * (time_allowed - time_now))
) )
@defer.inlineCallbacks async def maybe_kick_guest_users(self, event, context=None):
def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it. # Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller. # Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess: if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden") guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join": if guest_access != "can_join":
if context: if context:
current_state_ids = yield context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
current_state = yield self.store.get_events( current_state = await self.store.get_events(
list(current_state_ids.values()) list(current_state_ids.values())
) )
else: else:
current_state = yield self.state_handler.get_current_state( current_state = await self.state_handler.get_current_state(
event.room_id event.room_id
) )
current_state = list(current_state.values()) current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state) logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state) await self.kick_guest_users(current_state)
@defer.inlineCallbacks async def kick_guest_users(self, current_state):
def kick_guest_users(self, current_state):
for member_event in current_state: for member_event in current_state:
try: try:
if member_event.type != EventTypes.Member: if member_event.type != EventTypes.Member:
@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver. # homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True) requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler() handler = self.hs.get_room_member_handler()
yield handler.update_membership( await handler.update_membership(
requester, requester,
target_user, target_user,
member_event.room_id, member_event.room_id,

View file

@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
room_alias, room_id, servers, creator=creator room_alias, room_id, servers, creator=creator
) )
@defer.inlineCallbacks async def create_association(
def create_association(
self, self,
requester: Requester, requester: Requester,
room_alias: RoomAlias, room_alias: RoomAlias,
@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
else: else:
# Server admins are not subject to the same constraints as normal # Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room). # users when creating an alias (e.g. being in the room).
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
if (self.require_membership and check_membership) and not is_admin: if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = yield self.store.get_rooms_for_user(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user: if room_id not in rooms_for_user:
raise AuthError( raise AuthError(
403, "You must be in the room to create an alias for it" 403, "You must be in the room to create an alias for it"
@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule? # per alias creation rule?
raise SynapseError(403, "Not allowed to create alias") raise SynapseError(403, "Not allowed to create alias")
can_create = yield self.can_modify_alias(room_alias, user_id=user_id) can_create = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_create: if not can_create:
raise AuthError( raise AuthError(
400, 400,
@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
yield self._create_association(room_alias, room_id, servers, creator=user_id) await self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks async def delete_association(self, requester: Requester, room_alias: RoomAlias):
def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory """Remove an alias from the directory
(this is only meant for human users; AS users should call (this is only meant for human users; AS users should call
@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
can_delete = yield self._user_can_delete_alias(room_alias, user_id) can_delete = await self._user_can_delete_alias(room_alias, user_id)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Unknown room alias") raise NotFoundError("Unknown room alias")
@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete: if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.") raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete: if not can_delete:
raise SynapseError( raise SynapseError(
400, 400,
@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
room_id = yield self._delete_association(room_alias) room_id = await self._delete_association(room_alias)
try: try:
yield self._update_canonical_alias(requester, user_id, room_id, room_alias) await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e: except AuthError as e:
logger.info("Failed to update alias events: %s", e) logger.info("Failed to update alias events: %s", e)
@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND, Codes.NOT_FOUND,
) )
@defer.inlineCallbacks async def _update_canonical_alias(
def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
): ):
""" """
Send an updated canonical alias event if the removed alias was set as Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field. the canonical alias or listed in the alt_aliases field.
""" """
alias_event = yield self.state.get_current_state( alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, "" room_id, EventTypes.CanonicalAlias, ""
) )
@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"] del content["alt_aliases"]
if send_update: if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,
@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
return defer.succeed(True) return defer.succeed(True)
@defer.inlineCallbacks async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias. """Determine whether a user can delete an alias.
One of the following must be true: One of the following must be true:
@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
for the current room. for the current room.
""" """
creator = yield self.store.get_room_alias_creator(alias.to_string()) creator = await self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id: if creator is not None and creator == user_id:
return True return True
# Resolve the alias to the corresponding room. # Resolve the alias to the corresponding room.
room_mapping = yield self.get_association(alias) room_mapping = await self.get_association(alias)
room_id = room_mapping["room_id"] room_id = room_mapping["room_id"]
if not room_id: if not room_id:
return False return False
res = yield self.auth.check_can_change_room_list( res = await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id) room_id, UserID.from_string(user_id)
) )
return res return res
@defer.inlineCallbacks async def edit_published_room_list(
def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str self, requester: Requester, room_id: str, visibility: str
): ):
"""Edit the entry of the room in the published room list. """Edit the entry of the room in the published room list.
@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to publish rooms to the room list" 403, "This user is not permitted to publish rooms to the room list"
) )
room = yield self.store.get_room(room_id) room = await self.store.get_room(room_id)
if room is None: if room is None:
raise SynapseError(400, "Unknown room") raise SynapseError(400, "Unknown room")
can_change_room_list = yield self.auth.check_can_change_room_list( can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user room_id, requester.user
) )
if not can_change_room_list: if not can_change_room_list:
@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
making_public = visibility == "public" making_public = visibility == "public"
if making_public: if making_public:
room_aliases = yield self.store.get_aliases_for_room(room_id) room_aliases = await self.store.get_aliases_for_room(room_id)
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
if canonical_alias: if canonical_alias:
room_aliases.append(canonical_alias) room_aliases.append(canonical_alias)
@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule? # per alias creation rule?
raise SynapseError(403, "Not allowed to publish room") raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public) await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks @defer.inlineCallbacks
def edit_published_appservice_room_list( def edit_published_appservice_room_list(

View file

@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
} }
@defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed self, sender_user_id, target_user_id, room_id, signed
): ):
third_party_invite = {"signed": signed} third_party_invite = {"signed": signed}
@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id, "state_key": target_user_id,
} }
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if await self.auth.check_host_in_room(room_id, self.hs.hostname):
room_version = yield self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict) builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder) EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
event_allowed = yield self.third_party_event_rules.check_event_allowed( event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not event_allowed: if not event_allowed:
@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN 403, "This event is not allowed in this context", Codes.FORBIDDEN
) )
event, context = yield self.add_display_name_to_third_party_invite( event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context room_version, event_dict, event, context
) )
@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname event.internal_metadata.send_on_behalf_of = self.hs.hostname
try: try:
yield self.auth.check_from_context(room_version, event, context) await self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e) logger.warning("Denying new third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, context) await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler() member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context) await member_handler.send_membership_event(None, event, context)
else: else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite( await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict destinations, room_id, event_dict
) )

View file

@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy") set_group_join_policy = _create_rerouter("set_group_join_policy")
@defer.inlineCallbacks async def create_group(self, group_id, user_id, content):
def create_group(self, group_id, user_id, content):
"""Create a group """Create a group
""" """
logger.info("Asking to create group with ID: %r", group_id) logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
res = yield self.groups_server_handler.create_group( res = await self.groups_server_handler.create_group(
group_id, user_id, content group_id, user_id, content
) )
local_attestation = None local_attestation = None
@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = self.attestations.create_attestation(group_id, user_id) local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation content["attestation"] = local_attestation
content["user_profile"] = yield self.profile_handler.get_profile(user_id) content["user_profile"] = await self.profile_handler.get_profile(user_id)
try: try:
res = yield self.transport_client.create_group( res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content get_domain_from_id(group_id), group_id, user_id, content
) )
except HttpResponseException as e: except HttpResponseException as e:
@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
raise SynapseError(502, "Failed to contact group server") raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"] remote_attestation = res["attestation"]
yield self.attestations.verify_attestation( await self.attestations.verify_attestation(
remote_attestation, remote_attestation,
group_id=group_id, group_id=group_id,
user_id=user_id, user_id=user_id,
@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
) )
is_publicised = content.get("publicise", False) is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership( token = await self.store.register_user_group_membership(
group_id, group_id,
user_id, user_id,
membership="join", membership="join",
@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile} return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks async def remove_user_from_group(
def remove_user_from_group(self, group_id, user_id, requester_user_id, content): self, group_id, user_id, requester_user_id, content
):
"""Remove a user from a group """Remove a user from a group
""" """
if user_id == requester_user_id: if user_id == requester_user_id:
token = yield self.store.register_user_group_membership( token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave" group_id, user_id, membership="leave"
) )
self.notifier.on_new_event("groups_key", token, users=[user_id]) self.notifier.on_new_event("groups_key", token, users=[user_id])
@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# retry if the group server is currently down. # retry if the group server is currently down.
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group( res = await self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content group_id, user_id, requester_user_id, content
) )
else: else:
content["requester_user_id"] = requester_user_id content["requester_user_id"] = requester_user_id
try: try:
res = yield self.transport_client.remove_user_from_group( res = await self.transport_client.remove_user_from_group(
get_domain_from_id(group_id), get_domain_from_id(group_id),
group_id, group_id,
requester_user_id, requester_user_id,

View file

@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks async def send_nonmember_event(self, requester, event, context, ratelimit=True):
def send_nonmember_event(self, requester, event, context, ratelimit=True):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state(): if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context) prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None: if prev_state is not None:
logger.info( logger.info(
"Not bothering to persist state event %s duplicated by %s", "Not bothering to persist state event %s duplicated by %s",
@ -656,7 +655,7 @@ class EventCreationHandler(object):
) )
return prev_state return prev_state
yield self.handle_new_client_event( await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit requester=requester, event=event, context=context, ratelimit=ratelimit
) )
@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event return prev_event
return return
@defer.inlineCallbacks async def create_and_send_nonmember_event(
def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None self, requester, event_dict, ratelimit=True, txn_id=None
): ):
""" """
@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing # a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution # extremities to pile up, which in turn leads to state resolution
# taking longer. # taking longer.
with (yield self.limiter.queue(event_dict["room_id"])): with (await self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event( event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
) )
@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here" spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN) raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event( await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit requester, event, context, ratelimit=ratelimit
) )
return event return event
@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context) return (event, context)
@measure_func("handle_new_client_event") @measure_func("handle_new_client_event")
@defer.inlineCallbacks async def handle_new_client_event(
def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self, requester, event, context, ratelimit=True, extra_users=[]
): ):
"""Processes a new event. This includes checking auth, persisting it, """Processes a new event. This includes checking auth, persisting it,
@ -794,9 +791,9 @@ class EventCreationHandler(object):
): ):
room_version = event.content.get("room_version", RoomVersions.V1.identifier) room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else: else:
room_version = yield self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed( event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not event_allowed: if not event_allowed:
@ -805,7 +802,7 @@ class EventCreationHandler(object):
) )
try: try:
yield self.auth.check_from_context(room_version, event, context) await self.auth.check_from_context(room_version, event, context)
except AuthError as err: except AuthError as err:
logger.warning("Denying new event %r because %s", event, err) logger.warning("Denying new event %r because %s", event, err)
raise err raise err
@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
yield self.action_generator.handle_push_actions_for_event(event, context) await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead. # hack around with a try/finally instead.
@ -826,7 +823,7 @@ class EventCreationHandler(object):
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
if self.config.worker_app: if self.config.worker_app:
yield self.send_event_to_master( await self.send_event_to_master(
event_id=event.event_id, event_id=event.event_id,
store=self.store, store=self.store,
requester=requester, requester=requester,
@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True success = True
return return
yield self.persist_and_notify_client_event( await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users requester, event, context, ratelimit=ratelimit, extra_users=extra_users
) )
@ -883,8 +880,7 @@ class EventCreationHandler(object):
Codes.BAD_ALIAS, Codes.BAD_ALIAS,
) )
@defer.inlineCallbacks async def persist_and_notify_client_event(
def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self, requester, event, context, ratelimit=True, extra_users=[]
): ):
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
@ -901,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not). # user is actually admin or not).
is_admin_redaction = False is_admin_redaction = False
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
@ -913,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender original_event and event.sender != original_event.sender
) )
yield self.base_handler.ratelimit( await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction requester, is_admin_redaction=is_admin_redaction
) )
yield self.base_handler.maybe_kick_guest_users(event, context) await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases. # Validate a newly added alias or newly added alt_aliases.
@ -927,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state") original_event_id = event.unsigned.get("replaces_state")
if original_event_id: if original_event_id:
original_event = yield self.store.get_event(original_event_id) original_event = await self.store.get_event(original_event_id)
if original_event: if original_event:
original_alias = original_event.content.get("alias", None) original_alias = original_event.content.get("alias", None)
@ -937,7 +933,7 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None) room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias: if room_alias_str and room_alias_str != original_alias:
yield self._validate_canonical_alias( await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id directory_handler, room_alias_str, event.room_id
) )
@ -957,7 +953,7 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases: if new_alt_aliases:
for alias_str in new_alt_aliases: for alias_str in new_alt_aliases:
yield self._validate_canonical_alias( await self._validate_canonical_alias(
directory_handler, alias_str, event.room_id directory_handler, alias_str, event.room_id
) )
@ -969,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e): def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [ state_to_include_ids = [
e_id e_id
@ -978,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender) or k == (EventTypes.Member, event.sender)
] ]
state_to_include = yield self.store.get_events(state_to_include_ids) state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [ event.unsigned["invite_room_state"] = [
{ {
@ -996,8 +992,8 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need # way? If we have been invited by a remote server, we need
# to get them to sign the event. # to get them to sign the event.
returned_invite = yield defer.ensureDeferred( returned_invite = await federation_handler.send_invite(
federation_handler.send_invite(invitee.domain, event) invitee.domain, event
) )
event.unsigned.pop("room_state", None) event.unsigned.pop("room_state", None)
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id: if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room") raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction( if event_auth.check_redaction(
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids: if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden") raise AuthError(403, "Changing the room create event is forbidden")
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context event, context=context
) )
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry. # If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event) self._message_handler.maybe_schedule_expiry(event)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify(): def _notify():
try: try:
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
except Exception: except Exception:
logger.exception("Error bumping presence active time") logger.exception("Error bumping presence active time")
@defer.inlineCallbacks async def _send_dummy_events_to_fill_extremities(self):
def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large """Background task to send dummy events into rooms that have a large
number of extremities number of extremities
""" """
self._expire_rooms_to_exclude_from_dummy_event_insertion() self._expire_rooms_to_exclude_from_dummy_event_insertion()
room_ids = yield self.store.get_rooms_with_many_extremities( room_ids = await self.store.get_rooms_with_many_extremities(
min_count=10, min_count=10,
limit=5, limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send # For each room we need to find a joined member we can use to send
# the dummy event with. # the dummy event with.
latest_event_ids = yield self.store.get_prev_events_for_room(room_id) latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = yield self.state.get_current_users_in_room( members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids room_id, latest_event_ids=latest_event_ids
) )
dummy_event_sent = False dummy_event_sent = False
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
continue continue
requester = create_requester(user_id) requester = create_requester(user_id)
try: try:
event, context = yield self.create_event( event, context = await self.create_event(
requester, requester,
{ {
"type": "org.matrix.dummy_event", "type": "org.matrix.dummy_event",
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False event.internal_metadata.proactively_send = False
yield self.send_nonmember_event( await self.send_nonmember_event(
requester, event, context, ratelimit=False requester, event, context, ratelimit=False
) )
dummy_event_sent = True dummy_event_sent = True

View file

@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"] return result["displayname"]
@defer.inlineCallbacks async def set_displayname(
def set_displayname(self, target_user, requester, new_displayname, by_admin=False): self, target_user, requester, new_displayname, by_admin=False
):
"""Set the displayname of a user """Set the displayname of a user
Args: Args:
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname: if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name: if profile.display_name:
raise SynapseError( raise SynapseError(
400, 400,
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(target_user)
yield self.store.set_profile_displayname(target_user.localpart, new_displayname) await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change( await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile target_user.to_string(), profile
) )
yield self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"] return result["avatar_url"]
@defer.inlineCallbacks async def set_avatar_url(
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url: if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url: if profile.avatar_url:
raise SynapseError( raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(target_user)
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change( await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile target_user.to_string(), profile
) )
yield self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_profile_query(self, args): def on_profile_query(self, args):
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response return response
@defer.inlineCallbacks async def _update_join_states(self, requester, target_user):
def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
return return
yield self.ratelimit(requester) await self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids: for room_id in room_ids:
handler = self.hs.get_room_member_handler() handler = self.hs.get_room_member_handler()
try: try:
# Assume the target_user isn't a guest, # Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data. # because we don't let guests set profile or avatar data.
yield handler.update_membership( await handler.update_membership(
requester, requester,
target_user, target_user,
room_id, room_id,

View file

@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1 fail_count += 1
if not self.hs.config.user_consent_at_registration: if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id) yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else: else:
logger.info( logger.info(
"Skipping auto-join for %s because consent is required at registration", "Skipping auto-join for %s because consent is required at registration",
@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
return user_id return user_id
@defer.inlineCallbacks async def _auto_join_rooms(self, user_id):
def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place """Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created. if the user is the first to be created.
@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be # that an auto-generated support or bot user is not a real user and will never be
# the user to create the room # the user to create the room
should_auto_create_rooms = False should_auto_create_rooms = False
is_real_user = yield self.store.is_real_user(user_id) is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user: if self.hs.config.autocreate_auto_join_rooms and is_real_user:
count = yield self.store.count_real_users() count = await self.store.count_real_users()
should_auto_create_rooms = count == 1 should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms: for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r) logger.info("Auto-joining %s to %s", user_id, r)
@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency # getting the RoomCreationHandler during init gives a dependency
# loop # loop
yield self.hs.get_room_creation_handler().create_room( await self.hs.get_room_creation_handler().create_room(
fake_requester, fake_requester,
config={ config={
"preset": "public_chat", "preset": "public_chat",
@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
else: else:
yield self._join_user_to_room(fake_requester, r) await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e: except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though # Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do. # moving away from bare excepts is a good thing to do.
@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e: except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e) logger.error("Failed to join new user to %r: %r", r, e)
@defer.inlineCallbacks async def post_consent_actions(self, user_id):
def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent """A series of registration actions that can only be carried out once consent
has been granted has been granted
Args: Args:
user_id (str): The user to join user_id (str): The user to join
""" """
yield self._auto_join_rooms(user_id) await self._auto_join_rooms(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token): def appservice_register(self, user_localpart, as_token):
@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1 self._next_generated_user_id += 1
return str(id) return str(id)
@defer.inlineCallbacks async def _join_user_to_room(self, requester, room_identifier):
def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler() room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier): if RoomID.is_valid(room_identifier):
room_id = room_identifier room_id = room_identifier
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier) room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias room_alias
) )
room_id = room_id.to_string() room_id = room_id.to_string()
@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,) 400, "%s was not legal room ID or room alias" % (room_identifier,)
) )
yield room_member_handler.update_membership( await room_member_handler.update_membership(
requester=requester, requester=requester,
target=requester.user, target=requester.user,
room_id=room_id, room_id=room_id,
@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
return (device_id, access_token) return (device_id, access_token)
@defer.inlineCallbacks async def post_registration_actions(self, user_id, auth_result, access_token):
def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration """A user has completed registration
Args: Args:
@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled. device, or None if `inhibit_login` enabled.
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
yield self._post_registration_client( await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token user_id=user_id, auth_result=auth_result, access_token=access_token
) )
return return
@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved( if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid self.hs.config.mau_limits_reserved_threepids, threepid
): ):
yield self.store.upsert_monthly_active_user(user_id) await self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(user_id, threepid, access_token) await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result: if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN] threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(user_id, threepid) await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result: if auth_result and LoginType.TERMS in auth_result:
yield self._on_user_consented(user_id, self.hs.config.user_consent_version) await self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks async def _on_user_consented(self, user_id, consent_version):
def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration """A user consented to the terms on registration
Args: Args:
@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
consented to. consented to.
""" """
logger.info("%s has consented to the privacy policy", user_id) logger.info("%s has consented to the privacy policy", user_id)
yield self.store.user_set_consent_version(user_id, consent_version) await self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id) await self.post_consent_actions(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token): def _register_email_threepid(self, user_id, threepid, token):

View file

@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
return ret return ret
@defer.inlineCallbacks async def _upgrade_room(
def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion self, requester: Requester, old_room_id: str, new_version: RoomVersion
): ):
user_id = requester.user.to_string() user_id = requester.user.to_string()
# start by allocating a new room id # start by allocating a new room id
r = yield self.store.get_room(old_room_id) r = await self.store.get_room(old_room_id)
if r is None: if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,)) raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = yield self._generate_room_id( new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version, creator_id=user_id, is_public=r["is_public"], room_version=new_version,
) )
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
( (
tombstone_event, tombstone_event,
tombstone_context, tombstone_context,
) = yield self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
requester, requester,
{ {
"type": EventTypes.Tombstone, "type": EventTypes.Tombstone,
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
}, },
token_id=requester.access_token_id, token_id=requester.access_token_id,
) )
old_room_version = yield self.store.get_room_version_id(old_room_id) old_room_version = await self.store.get_room_version_id(old_room_id)
yield self.auth.check_from_context( await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context old_room_version, tombstone_event, tombstone_context
) )
yield self.clone_existing_room( await self.clone_existing_room(
requester, requester,
old_room_id=old_room_id, old_room_id=old_room_id,
new_room_id=new_room_id, new_room_id=new_room_id,
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
) )
# now send the tombstone # now send the tombstone
yield self.event_creation_handler.send_nonmember_event( await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context requester, tombstone_event, tombstone_context
) )
old_room_state = yield tombstone_context.get_current_state_ids() old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases # update any aliases
yield self._move_aliases_to_new_room( await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state requester, old_room_id, new_room_id, old_room_state
) )
# Copy over user push rules, tags and migrate room directory state # Copy over user push rules, tags and migrate room directory state
yield self.room_member_handler.transfer_room_state_on_room_upgrade( await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id old_room_id, new_room_id
) )
# finally, shut down the PLs in the old room, and update them in the new # finally, shut down the PLs in the old room, and update them in the new
# room. # room.
yield self._update_upgraded_room_pls( await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state, requester, old_room_id, new_room_id, old_room_state,
) )
return new_room_id return new_room_id
@defer.inlineCallbacks async def _update_upgraded_room_pls(
def _update_upgraded_room_pls(
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
) )
return return
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required # we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally # to send regular events and invites to 'Moderator' level. That's normally
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
if updated: if updated:
try: try:
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.PowerLevels, "type": EventTypes.PowerLevels,
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e: except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e) logger.warning("Unable to update PLs in old room: %s", e)
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.PowerLevels, "type": EventTypes.PowerLevels,
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
@defer.inlineCallbacks async def clone_existing_room(
def clone_existing_room(
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable # Check if old room was non-federatable
# Get old room's create event # Get old room's create event
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room # Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True): if not old_room_create_event.content.get("m.federate", True):
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""), (EventTypes.PowerLevels, ""),
) )
old_room_state_ids = yield self.store.get_filtered_current_state_ids( old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy) old_room_id, StateFilter.from_types(types_to_copy)
) )
# map from event_id to BaseEvent # map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids): for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id) old_event = old_room_state_events.get(old_event_id)
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level: if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level power_levels["users"][user_id] = needed_power_level
yield self._send_events_for_new_room( await self._send_events_for_new_room(
requester, requester,
new_room_id, new_room_id,
# we expect to override all the presets with initial_state, so this is # we expect to override all the presets with initial_state, so this is
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
) )
# Transfer membership events # Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
# map from event_id to BaseEvent # map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events( old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values() old_room_member_state_ids.values()
) )
for k, old_event in iteritems(old_room_member_state_events): for k, old_event in iteritems(old_room_member_state_events):
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content "membership" in old_event.content
and old_event.content["membership"] == "ban" and old_event.content["membership"] == "ban"
): ):
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester, requester,
UserID.from_string(old_event["state_key"]), UserID.from_string(old_event["state_key"]),
new_room_id, new_room_id,
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins # XXX invites/joins
# XXX 3pid invites # XXX 3pid invites
@defer.inlineCallbacks async def _move_aliases_to_new_room(
def _move_aliases_to_new_room(
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
): ):
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id) aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias. # check to see if we have a canonical alias.
canonical_alias_event = None canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id: if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending # first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end). # the room_aliases event until the end).
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases: for alias_str in aliases:
alias = RoomAlias.from_string(alias_str) alias = RoomAlias.from_string(alias_str)
try: try:
yield directory_handler.delete_association(requester, alias) await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str) removed_aliases.append(alias_str)
except SynapseError as e: except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e) logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room. # we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases: for alias in removed_aliases:
try: try:
yield directory_handler.create_association( await directory_handler.create_association(
requester, requester,
RoomAlias.from_string(alias), RoomAlias.from_string(alias),
new_room_id, new_room_id,
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information. # alias event for the new room with a copy of the information.
try: try:
if canonical_alias_event: if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point. # we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e) logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks async def create_room(
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): self, requester, config, ratelimit=True, creator_join_profile=None
):
""" Creates a new room. """ Creates a new room.
Args: Args:
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
if ( if (
self._server_notices_mxid is not None self._server_notices_mxid is not None
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms # allow the server notices mxid to create rooms
is_requester_admin = True is_requester_admin = True
else: else:
is_requester_admin = yield self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create # Check whether the third party rules allows/changes the room create
# request. # request.
event_allowed = yield self.third_party_event_rules.on_create_room( event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin requester, config, is_requester_admin=is_requester_admin
) )
if not event_allowed: if not event_allowed:
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit: if ratelimit:
yield self.ratelimit(requester) await self.ratelimit(requester)
room_version_id = config.get( room_version_id = config.get(
"room_version", self.config.default_room_version.identifier "room_version", self.config.default_room_version.identifier
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias") raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = yield self.store.get_association_from_room_alias(room_alias) mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping: if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
except Exception: except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,)) raise SynapseError(400, "Invalid user_id: %s" % (i,))
yield self.event_creation_handler.assert_accepted_privacy_policy(requester) await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override") power_level_content_override = config.get("power_level_content_override")
if ( if (
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None) visibility = config.get("visibility", None)
is_public = visibility == "public" is_public = visibility == "public"
room_id = yield self._generate_room_id( room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version, creator_id=user_id, is_public=is_public, room_version=room_version,
) )
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
if room_alias: if room_alias:
yield directory_handler.create_association( await directory_handler.create_association(
requester=requester, requester=requester,
room_id=room_id, room_id=room_id,
room_alias=room_alias, room_alias=room_alias,
@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content # override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier creation_content["room_version"] = room_version.identifier
yield self._send_events_for_new_room( await self._send_events_for_new_room(
requester, requester,
room_id, room_id,
preset_config=preset_config, preset_config=preset_config,
@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Name, "type": EventTypes.Name,
@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Topic, "type": EventTypes.Topic,
@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct: if is_direct:
content["is_direct"] = is_direct content["is_direct"] = is_direct
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester, requester,
UserID.from_string(invitee), UserID.from_string(invitee),
room_id, room_id,
@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"] address = invite_3pid["address"]
medium = invite_3pid["medium"] medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite( await self.hs.get_room_member_handler().do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
medium, medium,
@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler):
return result return result
@defer.inlineCallbacks async def _send_events_for_new_room(
def _send_events_for_new_room(
self, self,
creator, # A Requester object. creator, # A Requester object.
room_id, room_id,
@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler):
return e return e
@defer.inlineCallbacks async def send(etype, content, **kwargs):
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype) logger.debug("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False creator, event, ratelimit=False
) )
@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id}) creation_content.update({"creator": creator_id})
yield send(etype=EventTypes.Create, content=creation_content) await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member) logger.debug("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
creator, creator,
creator.user, creator.user,
room_id, room_id,
@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room. # of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None: if pl_content is not None:
yield send(etype=EventTypes.PowerLevels, content=pl_content) await send(etype=EventTypes.PowerLevels, content=pl_content)
else: else:
power_level_content = { power_level_content = {
"users": {creator_id: 100}, "users": {creator_id: 100},
@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override: if power_level_content_override:
power_level_content.update(power_level_content_override) power_level_content.update(power_level_content_override)
yield send(etype=EventTypes.PowerLevels, content=power_level_content) await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send( await send(
etype=EventTypes.CanonicalAlias, etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()}, content={"alias": room_alias.to_string()},
) )
if (EventTypes.JoinRules, "") not in initial_state: if (EventTypes.JoinRules, "") not in initial_state:
yield send( await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
) )
if (EventTypes.RoomHistoryVisibility, "") not in initial_state: if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send( await send(
etype=EventTypes.RoomHistoryVisibility, etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]}, content={"history_visibility": config["history_visibility"]},
) )
if config["guest_can_join"]: if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state: if (EventTypes.GuestAccess, "") not in initial_state:
yield send( await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
) )
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content) await send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_room_id( def _generate_room_id(

View file

@ -142,8 +142,7 @@ class RoomMemberHandler(object):
""" """
raise NotImplementedError() raise NotImplementedError()
@defer.inlineCallbacks async def _local_membership_update(
def _local_membership_update(
self, self,
requester, requester,
target, target,
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
event, context = yield self.event_creation_handler.create_event( event, context = await self.event_creation_handler.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event( duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context event, context
) )
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
return duplicate return duplicate
yield self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit requester, event, context, extra_users=[target], ratelimit=ratelimit
) )
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@ -203,15 +202,15 @@ class RoomMemberHandler(object):
# info. # info.
newly_joined = True newly_joined = True
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield self._user_joined_room(target, room_id) await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id) await self._user_left_room(target, room_id)
return event return event
@ -253,8 +252,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items(): for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks async def update_membership(
def update_membership(
self, self,
requester, requester,
target, target,
@ -269,8 +267,8 @@ class RoomMemberHandler(object):
): ):
key = (room_id,) key = (room_id,)
with (yield self.member_linearizer.queue(key)): with (await self.member_linearizer.queue(key)):
result = yield self._update_membership( result = await self._update_membership(
requester, requester,
target, target,
room_id, room_id,
@ -285,8 +283,7 @@ class RoomMemberHandler(object):
return result return result
@defer.inlineCallbacks async def _update_membership(
def _update_membership(
self, self,
requester, requester,
target, target,
@ -321,7 +318,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid # if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join. # invite into a normal invite before we can handle the join.
if third_party_signed is not None: if third_party_signed is not None:
yield self.federation_handler.exchange_third_party_invite( await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"], third_party_signed["sender"],
target.to_string(), target.to_string(),
room_id, room_id,
@ -332,7 +329,7 @@ class RoomMemberHandler(object):
remote_room_hosts = [] remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"): if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id) is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
@ -351,7 +348,7 @@ class RoomMemberHandler(object):
is_requester_admin = True is_requester_admin = True
else: else:
is_requester_admin = yield self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin: if not is_requester_admin:
if self.config.block_non_admin_invites: if self.config.block_non_admin_invites:
@ -370,9 +367,9 @@ class RoomMemberHandler(object):
if block_invite: if block_invite:
raise SynapseError(403, "Invites have been disabled on this server") raise SynapseError(403, "Invites have been disabled on this server")
latest_event_ids = yield self.store.get_prev_events_for_room(room_id) latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids room_id, latest_event_ids=latest_event_ids
) )
@ -381,7 +378,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise # transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id: if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True) old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban": if action == "unban" and old_membership != "ban":
raise SynapseError( raise SynapseError(
@ -413,7 +410,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE and effective_membership_state == Membership.LEAVE
): ):
is_blocked = yield self._is_server_notice_room(room_id) is_blocked = await self._is_server_notice_room(room_id)
if is_blocked: if is_blocked:
raise SynapseError( raise SynapseError(
http_client.FORBIDDEN, http_client.FORBIDDEN,
@ -424,18 +421,18 @@ class RoomMemberHandler(object):
if action == "kick": if action == "kick":
raise AuthError(403, "The target user is not in the room") raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids) is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN: if effective_membership_state == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(current_state_ids) guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if not is_host_in_room:
inviter = yield self._get_inviter(target.to_string(), room_id) inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
@ -443,13 +440,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler profile = self.profile_handler
if not content_specified: if not content_specified:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = await profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
remote_join_response = yield self._remote_join( remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content requester, remote_room_hosts, room_id, target, content
) )
@ -458,7 +455,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # perhaps we've been invited
inviter = yield self._get_inviter(target.to_string(), room_id) inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter: if not inviter:
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
@ -472,12 +469,12 @@ class RoomMemberHandler(object):
else: else:
# send the rejection to the inviter's HS. # send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain] remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite( res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content, requester, remote_room_hosts, room_id, target, content,
) )
return res return res
res = yield self._local_membership_update( res = await self._local_membership_update(
requester=requester, requester=requester,
target=target, target=target,
room_id=room_id, room_id=room_id,
@ -572,8 +569,7 @@ class RoomMemberHandler(object):
) )
continue continue
@defer.inlineCallbacks async def send_membership_event(self, requester, event, context, ratelimit=True):
def send_membership_event(self, requester, event, context, ratelimit=True):
""" """
Change the membership status of a user in a room. Change the membership status of a user in a room.
@ -599,27 +595,27 @@ class RoomMemberHandler(object):
else: else:
requester = types.create_requester(target_user) requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event( prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context event, context
) )
if prev_event is not None: if prev_event is not None:
return return
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids) guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN): if event.membership not in (Membership.LEAVE, Membership.BAN):
is_blocked = yield self.store.is_room_blocked(room_id) is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit requester, event, context, extra_users=[target_user], ratelimit=ratelimit
) )
@ -633,15 +629,15 @@ class RoomMemberHandler(object):
# info. # info.
newly_joined = True newly_joined = True
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield self._user_joined_room(target_user, room_id) await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target_user, room_id) await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _can_guest_join(self, current_state_ids): def _can_guest_join(self, current_state_ids):
@ -699,8 +695,7 @@ class RoomMemberHandler(object):
if invite: if invite:
return UserID.from_string(invite.sender) return UserID.from_string(invite.sender)
@defer.inlineCallbacks async def do_3pid_invite(
def do_3pid_invite(
self, self,
room_id, room_id,
inviter, inviter,
@ -712,7 +707,7 @@ class RoomMemberHandler(object):
id_access_token=None, id_access_token=None,
): ):
if self.config.block_non_admin_invites: if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin: if not is_requester_admin:
raise SynapseError( raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN 403, "Invites have been disabled on this server", Codes.FORBIDDEN
@ -720,9 +715,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we # We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events. # can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester) await self.base_handler.ratelimit(requester)
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id medium, address, room_id
) )
if not can_invite: if not can_invite:
@ -737,16 +732,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server" 403, "Looking up third-party identifiers is denied from this server"
) )
invitee = yield self.identity_handler.lookup_3pid( invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token id_server, medium, address, id_access_token
) )
if invitee: if invitee:
yield self.update_membership( await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
) )
else: else:
yield self._make_and_store_3pid_invite( await self._make_and_store_3pid_invite(
requester, requester,
id_server, id_server,
medium, medium,
@ -757,8 +752,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
@defer.inlineCallbacks async def _make_and_store_3pid_invite(
def _make_and_store_3pid_invite(
self, self,
requester, requester,
id_server, id_server,
@ -769,7 +763,7 @@ class RoomMemberHandler(object):
txn_id, txn_id,
id_access_token=None, id_access_token=None,
): ):
room_state = yield self.state_handler.get_current_state(room_id) room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = "" inviter_display_name = ""
inviter_avatar_url = "" inviter_avatar_url = ""
@ -807,7 +801,7 @@ class RoomMemberHandler(object):
public_keys, public_keys,
fallback_public_key, fallback_public_key,
display_name, display_name,
) = yield self.identity_handler.ask_id_server_for_third_party_invite( ) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester, requester=requester,
id_server=id_server, id_server=id_server,
medium=medium, medium=medium,
@ -823,7 +817,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity return complexity["v1"] > max_complexity
@defer.inlineCallbacks async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join
""" """
# filter ourselves out of remote_room_hosts: do_invite_join ignores it # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled: if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity # Fetch the room complexity
too_complex = yield self._is_remote_room_too_complex( too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts room_id, remote_room_hosts
) )
if too_complex is True: if too_complex is True:
@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking # join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
yield defer.ensureDeferred( await self.federation_handler.do_invite_join(
self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content remote_room_hosts, room_id, user.to_string(), content
) )
) await self._user_joined_room(user, room_id)
yield self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the # Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before. # complexity of it before.
@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return return
# Check again, but with the local state events # Check again, but with the local state events
too_complex = yield self._is_local_room_too_complex(room_id) too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False: if too_complex is False:
# We're under the limit. # We're under the limit.
@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave. # The room is too large. Leave.
requester = types.create_requester(user, None, False, None) requester = types.create_requester(user, None, False, None)
yield self.update_membership( await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave" requester=requester, target=user, room_id=room_id, action="leave"
) )
raise SynapseError( raise SynapseError(
@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id): def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room """Implements RoomMemberHandler._user_joined_room
""" """
return user_joined_room(self.distributor, target, room_id) return defer.succeed(user_joined_room(self.distributor, target, room_id))
def _user_left_room(self, target, room_id): def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room """Implements RoomMemberHandler._user_left_room
""" """
return user_left_room(self.distributor, target, room_id) return defer.succeed(user_left_room(self.distributor, target, room_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def forget(self, user, room_id): def forget(self, user, room_id):

View file

@ -16,8 +16,6 @@ import logging
from six import iteritems, string_types from six import iteritems, string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError from synapse.config import ConfigError
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config) self._consent_uri_builder = ConsentURIBuilder(hs.config)
@defer.inlineCallbacks async def maybe_send_server_notice_to_user(self, user_id):
def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so """Check if we need to send a notice to this user, and does so if so
Args: Args:
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
return return
self._users_in_progress.add(user_id) self._users_in_progress.add(user_id)
try: try:
u = yield self._store.get_user_by_id(user_id) u = await self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests: if u["is_guest"] and not self._send_to_guests:
# don't send to guests # don't send to guests
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
content = copy_with_str_subst( content = copy_with_str_subst(
self._server_notice_content, {"consent_uri": consent_uri} self._server_notice_content, {"consent_uri": consent_uri}
) )
yield self._server_notices_manager.send_notice(user_id, content) await self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent( await self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version user_id, self._current_consent_version
) )
except SynapseError as e: except SynapseError as e:

View file

@ -50,8 +50,7 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier() self._notifier = hs.get_notifier()
@defer.inlineCallbacks async def maybe_send_server_notice_to_user(self, user_id):
def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in """Check if we need to send a notice to this user, this will be true in
two cases. two cases.
1. The server has reached its limit does not reflect this 1. The server has reached its limit does not reflect this
@ -74,13 +73,13 @@ class ResourceLimitsServerNotices(object):
# Don't try and send server notices unless they've been enabled # Don't try and send server notices unless they've been enabled
return return
timestamp = yield self._store.user_last_seen_monthly_active(user_id) timestamp = await self._store.user_last_seen_monthly_active(user_id)
if timestamp is None: if timestamp is None:
# This user will be blocked from receiving the notice anyway. # This user will be blocked from receiving the notice anyway.
# In practice, not sure we can ever get here # In practice, not sure we can ever get here
return return
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
user_id user_id
) )
@ -88,10 +87,10 @@ class ResourceLimitsServerNotices(object):
logger.warning("Failed to get server notices room") logger.warning("Failed to get server notices room")
return return
yield self._check_and_set_tags(user_id, room_id) await self._check_and_set_tags(user_id, room_id)
# Determine current state of room # Determine current state of room
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
limit_msg = None limit_msg = None
limit_type = None limit_type = None
@ -99,7 +98,7 @@ class ResourceLimitsServerNotices(object):
# Normally should always pass in user_id to check_auth_blocking # Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen # if you have it, but in this case are checking what would happen
# to other users if they were to arrive. # to other users if they were to arrive.
yield self._auth.check_auth_blocking() await self._auth.check_auth_blocking()
except ResourceLimitError as e: except ResourceLimitError as e:
limit_msg = e.msg limit_msg = e.msg
limit_type = e.limit_type limit_type = e.limit_type
@ -112,22 +111,21 @@ class ResourceLimitsServerNotices(object):
# We have hit the MAU limit, but MAU alerting is disabled: # We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return # reset room if necessary and return
if currently_blocked: if currently_blocked:
self._remove_limit_block_notification(user_id, ref_events) await self._remove_limit_block_notification(user_id, ref_events)
return return
if currently_blocked and not limit_msg: if currently_blocked and not limit_msg:
# Room is notifying of a block, when it ought not to be. # Room is notifying of a block, when it ought not to be.
yield self._remove_limit_block_notification(user_id, ref_events) await self._remove_limit_block_notification(user_id, ref_events)
elif not currently_blocked and limit_msg: elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be. # Room is not notifying of a block, when it ought to be.
yield self._apply_limit_block_notification( await self._apply_limit_block_notification(
user_id, limit_msg, limit_type user_id, limit_msg, limit_type
) )
except SynapseError as e: except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e) logger.error("Error sending resource limits server notice: %s", e)
@defer.inlineCallbacks async def _remove_limit_block_notification(self, user_id, ref_events):
def _remove_limit_block_notification(self, user_id, ref_events):
"""Utility method to remove limit block notifications from the server """Utility method to remove limit block notifications from the server
notices room. notices room.
@ -137,12 +135,13 @@ class ResourceLimitsServerNotices(object):
limit blocking and need to be preserved. limit blocking and need to be preserved.
""" """
content = {"pinned": ref_events} content = {"pinned": ref_events}
yield self._server_notices_manager.send_notice( await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, "" user_id, content, EventTypes.Pinned, ""
) )
@defer.inlineCallbacks async def _apply_limit_block_notification(
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): self, user_id, event_body, event_limit_type
):
"""Utility method to apply limit block notifications in the server """Utility method to apply limit block notifications in the server
notices room. notices room.
@ -159,12 +158,12 @@ class ResourceLimitsServerNotices(object):
"admin_contact": self._config.admin_contact, "admin_contact": self._config.admin_contact,
"limit_type": event_limit_type, "limit_type": event_limit_type,
} }
event = yield self._server_notices_manager.send_notice( event = await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Message user_id, content, EventTypes.Message
) )
content = {"pinned": [event.event_id]} content = {"pinned": [event.event_id]}
yield self._server_notices_manager.send_notice( await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, "" user_id, content, EventTypes.Pinned, ""
) )
@ -198,7 +197,7 @@ class ResourceLimitsServerNotices(object):
room_id(str): The room id of the server notices room room_id(str): The room id of the server notices room
Returns: Returns:
Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking list: The list of pinned events that are unrelated to limit blocking
This list can be used as a convenience in the case where the block This list can be used as a convenience in the case where the block

View file

@ -14,11 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
""" """
return self._config.server_notices_mxid is not None return self._config.server_notices_mxid is not None
@defer.inlineCallbacks async def send_notice(
def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None self, user_id, event_content, type=EventTypes.Message, state_key=None
): ):
"""Send a notice to the given user """Send a notice to the given user
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
Returns: Returns:
Deferred[FrozenEvent] Deferred[FrozenEvent]
""" """
room_id = yield self.get_or_create_notice_room_for_user(user_id) room_id = await self.get_or_create_notice_room_for_user(user_id)
yield self.maybe_invite_user_to_room(user_id, room_id) await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid) requester = create_requester(system_mxid)
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
if state_key is not None: if state_key is not None:
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event( res = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False requester, event_dict, ratelimit=False
) )
return res return res
@cachedInlineCallbacks() @cached()
def get_or_create_notice_room_for_user(self, user_id): async def get_or_create_notice_room_for_user(self, user_id):
"""Get the room for notices for a given user """Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't If we have not yet created a notice room for this user, create it, but don't
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users" assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_local_user_where_membership_is( rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN] user_id, [Membership.INVITE, Membership.JOIN]
) )
for room in rooms: for room in rooms:
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
# be joined. This is kinda deliberate, in that if somebody somehow # be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it # manages to invite the system user to a room, that doesn't make it
# the server notices room. # the server notices room.
user_ids = yield self._store.get_users_in_room(room.room_id) user_ids = await self._store.get_users_in_room(room.room_id)
if self.server_notices_mxid in user_ids: if self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice # we found a room which our user shares with the system notice
# user # user
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
} }
requester = create_requester(self.server_notices_mxid) requester = create_requester(self.server_notices_mxid)
info = yield self._room_creation_handler.create_room( info = await self._room_creation_handler.create_room(
requester, requester,
config={ config={
"preset": RoomCreationPreset.PRIVATE_CHAT, "preset": RoomCreationPreset.PRIVATE_CHAT,
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
) )
room_id = info["room_id"] room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room( max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
) )
self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id) logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id return room_id
@defer.inlineCallbacks async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
"""Invite the given user to the given server room, unless the user has already """Invite the given user to the given server room, unless the user has already
joined or been invited to it. joined or been invited to it.
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
# Check whether the user has already joined or been invited to this room. If # Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them. # that's the case, there is no need to re-invite them.
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN] user_id, [Membership.INVITE, Membership.JOIN]
) )
for room in joined_rooms: for room in joined_rooms:
if room.room_id == room_id: if room.room_id == room_id:
return return
yield self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
requester=requester, requester=requester,
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=room_id, room_id=room_id,

View file

@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import ( from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices, ResourceLimitsServerNotices,
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
ResourceLimitsServerNotices(hs), ResourceLimitsServerNotices(hs),
) )
@defer.inlineCallbacks async def on_user_syncing(self, user_id):
def on_user_syncing(self, user_id):
"""Called when the user performs a sync operation. """Called when the user performs a sync operation.
Args: Args:
user_id (str): mxid of user who synced user_id (str): mxid of user who synced
""" """
for sn in self._server_notices: for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id) await sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks async def on_user_ip(self, user_id):
def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request. """Called on the master when a worker process saw a client request.
Args: Args:
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as # we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing # in on_user_syncing
for sn in self._server_notices: for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id) await sn.maybe_send_server_notice_to_user(user_id)

View file

@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user", desc="delete_account_validity_for_user",
) )
@defer.inlineCallbacks async def is_server_admin(self, user):
def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver. """Determines if a user is an admin of this homeserver.
Args: Args:
@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool): Returns (bool):
true iff the user is a server admin, false otherwise. true iff the user is a server admin, false otherwise.
""" """
res = yield self.db.simple_select_one_onecol( res = await self.db.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user.to_string()}, keyvalues={"name": user.to_string()},
retcol="admin", retcol="admin",

View file

@ -82,19 +82,27 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
yield self.handler.set_displayname( yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
) )
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), (
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank Jr.", "Frank Jr.",
) )
# Set displayname again # Set displayname again
yield self.handler.set_displayname( yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank" self.frank, synapse.types.create_requester(self.frank), "Frank"
) )
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
@ -112,17 +120,21 @@ class ProfileTestCase(unittest.TestCase):
) )
# Setting displayname a second time is forbidden # Setting displayname a second time is forbidden
d = self.handler.set_displayname( d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
) )
)
yield self.assertFailure(d, SynapseError) yield self.assertFailure(d, SynapseError)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname( d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr." self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
) )
)
yield self.assertFailure(d, AuthError) yield self.assertFailure(d, AuthError)
@ -165,11 +177,13 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield self.handler.set_avatar_url( yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank, self.frank,
synapse.types.create_requester(self.frank), synapse.types.create_requester(self.frank),
"http://my.server/pic.gif", "http://my.server/pic.gif",
) )
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (yield self.store.get_profile_avatar_url(self.frank.localpart)),
@ -177,11 +191,13 @@ class ProfileTestCase(unittest.TestCase):
) )
# Set avatar again # Set avatar again
yield self.handler.set_avatar_url( yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank, self.frank,
synapse.types.create_requester(self.frank), synapse.types.create_requester(self.frank),
"http://my.server/me.png", "http://my.server/me.png",
) )
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (yield self.store.get_profile_avatar_url(self.frank.localpart)),
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
) )
# Set avatar a second time is forbidden # Set avatar a second time is forbidden
d = self.handler.set_avatar_url( d = defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank, self.frank,
synapse.types.create_requester(self.frank), synapse.types.create_requester(self.frank),
"http://my.server/pic.gif", "http://my.server/pic.gif",
) )
)
yield self.assertFailure(d, SynapseError) yield self.assertFailure(d, SynapseError)

View file

@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=False) self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support")) user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=1) self.store.count_real_users = Mock(return_value=defer.succeed(1))
self.store.is_real_user = Mock(return_value=True) self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=2) self.store.count_real_users = Mock(return_value=defer.succeed(2))
self.store.is_real_user = Mock(return_value=True) self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError self.handler.register_user(localpart=invalid_user_id), SynapseError
) )
@defer.inlineCallbacks async def get_or_create_user(
def get_or_create_user(self, requester, localpart, displayname, password_hash=None): self, requester, localpart, displayname, password_hash=None
):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self.hs.get_auth().check_auth_blocking() await self.hs.get_auth().check_auth_blocking()
need_register = True need_register = True
try: try:
yield self.handler.check_username(localpart) await self.handler.check_username(localpart)
except SynapseError as e: except SynapseError as e:
if e.errcode == Codes.USER_IN_USE: if e.errcode == Codes.USER_IN_USE:
need_register = False need_register = False
@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id) token = self.macaroon_generator.generate_access_token(user_id)
if need_register: if need_register:
yield self.handler.register_with_store( await self.handler.register_with_store(
user_id=user_id, user_id=user_id,
password_hash=password_hash, password_hash=password_hash,
create_profile_with_displayname=user.localpart, create_profile_with_displayname=user.localpart,
) )
else: else:
yield defer.ensureDeferred( await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
)
yield self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None user_id=user_id, token=token, device_id=None, valid_until_ms=None
) )
if displayname is not None: if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname) # logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.hs.get_profile_handler().set_displayname( await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True user, requester, displayname, by_admin=True
) )

View file

@ -55,25 +55,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000) return_value=defer.succeed(1000)
) )
self._send_notice = self._rlsn._server_notices_manager.send_notice self._rlsn._server_notices_manager.send_notice = Mock(
self._rlsn._server_notices_manager.send_notice = Mock() return_value=defer.succeed(Mock())
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) )
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
self._send_notice = self._rlsn._server_notices_manager.send_notice self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
# self.server_notices_mxid = "@server:test"
# self.server_notices_mxid_display_name = None
# self.server_notices_mxid_avatar_url = None
# self.server_notices_room_name = "Server Notices"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
returnValue="" return_value=defer.succeed("!something:localhost")
) )
self._rlsn._store.add_tag_to_room = Mock() self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value={}) self._rlsn._store.get_tags_for_room = Mock(return_value={})
self.hs.config.admin_contact = "mailto:user@test.com" self.hs.config.admin_contact = "mailto:user@test.com"
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event}) return_value=defer.succeed({"123": mock_event})
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once() self._send_notice.assert_called_once()
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP) Test when user has blocked notice, but notice ought to be there (NOOP)
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo") return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
) )
mock_event = Mock( mock_event = Mock(
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event}) return_value=defer.succeed({"123": mock_event})
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, but should have one Test when user does not have blocked notice, but should have one
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo") return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, nor should they (NOOP) Test when user does not have blocked notice, nor should they (NOOP)
""" """
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
""" """
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None) return_value=defer.succeed(None)
) )
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room an alert message is not sent into the room
""" """
self.hs.config.mau_limit_alerting = False self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
) ),
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self.assertTrue(self._send_notice.call_count == 0) self.assertEqual(self._send_notice.call_count, 0)
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
""" """
Test that when a server is disabled, that MAU limit alerting is ignored. Test that when a server is disabled, that MAU limit alerting is ignored.
""" """
self.hs.config.mau_limit_alerting = False self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
) ),
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
self.hs.config.mau_limit_alerting = False self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
) )
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, [])) return_value=defer.succeed((True, []))
) )
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000) self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

View file

@ -27,9 +27,11 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test") user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None) our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler() room_creator = self.homeserver.get_room_creation_handler()
room = room_creator.create_room( room = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
) )
)
self.reactor.advance(0.1) self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"] self.room_id = self.successResultOf(room)["room_id"]