Port some admin handlers to async/await (#6559)

This commit is contained in:
Erik Johnston 2019-12-19 15:07:28 +00:00 committed by GitHub
parent bca30cefee
commit 3d46124ad0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 56 deletions

1
changelog.d/6559.misc Normal file
View file

@ -0,0 +1 @@
Port `synapse.handlers.admin` and `synapse.handlers.deactivate_account` to async/await.

View file

@ -104,8 +104,10 @@ def export_data_command(hs, args):
user_id = args.user_id user_id = args.user_id
directory = args.output_directory directory = args.output_directory
res = yield hs.get_handlers().admin_handler.export_user_data( res = yield defer.ensureDeferred(
user_id, FileExfiltrationWriter(user_id, directory=directory) hs.get_handlers().admin_handler.export_user_data(
user_id, FileExfiltrationWriter(user_id, directory=directory)
)
) )
print(res) print(res)

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -33,11 +31,10 @@ class AdminHandler(BaseHandler):
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
@defer.inlineCallbacks async def get_whois(self, user):
def get_whois(self, user):
connections = [] connections = []
sessions = yield self.store.get_user_ip_and_agents(user) sessions = await self.store.get_user_ip_and_agents(user)
for session in sessions: for session in sessions:
connections.append( connections.append(
{ {
@ -54,20 +51,18 @@ class AdminHandler(BaseHandler):
return ret return ret
@defer.inlineCallbacks async def get_users(self):
def get_users(self):
"""Function to retrieve a list of users in users table. """Function to retrieve a list of users in users table.
Args: Args:
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
ret = yield self.store.get_users() ret = await self.store.get_users()
return ret return ret
@defer.inlineCallbacks async def get_users_paginate(self, start, limit, name, guests, deactivated):
def get_users_paginate(self, start, limit, name, guests, deactivated):
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from
users list. This will return a json list of users. users list. This will return a json list of users.
@ -80,14 +75,13 @@ class AdminHandler(BaseHandler):
Returns: Returns:
defer.Deferred: resolves to json list[dict[str, Any]] defer.Deferred: resolves to json list[dict[str, Any]]
""" """
ret = yield self.store.get_users_paginate( ret = await self.store.get_users_paginate(
start, limit, name, guests, deactivated start, limit, name, guests, deactivated
) )
return ret return ret
@defer.inlineCallbacks async def search_users(self, term):
def search_users(self, term):
"""Function to search users list for one or more users with """Function to search users list for one or more users with
the matched term. the matched term.
@ -96,7 +90,7 @@ class AdminHandler(BaseHandler):
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
ret = yield self.store.search_users(term) ret = await self.store.search_users(term)
return ret return ret
@ -119,8 +113,7 @@ class AdminHandler(BaseHandler):
""" """
return self.store.set_server_admin(user, admin) return self.store.set_server_admin(user, admin)
@defer.inlineCallbacks async def export_user_data(self, user_id, writer):
def export_user_data(self, user_id, writer):
"""Write all data we have on the user to the given writer. """Write all data we have on the user to the given writer.
Args: Args:
@ -132,7 +125,7 @@ class AdminHandler(BaseHandler):
The returned value is that returned by `writer.finished()`. The returned value is that returned by `writer.finished()`.
""" """
# Get all rooms the user is in or has been in # Get all rooms the user is in or has been in
rooms = yield self.store.get_rooms_for_user_where_membership_is( rooms = await self.store.get_rooms_for_user_where_membership_is(
user_id, user_id,
membership_list=( membership_list=(
Membership.JOIN, Membership.JOIN,
@ -145,7 +138,7 @@ class AdminHandler(BaseHandler):
# We only try and fetch events for rooms the user has been in. If # We only try and fetch events for rooms the user has been in. If
# they've been e.g. invited to a room without joining then we handle # they've been e.g. invited to a room without joining then we handle
# those seperately. # those seperately.
rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
for index, room in enumerate(rooms): for index, room in enumerate(rooms):
room_id = room.room_id room_id = room.room_id
@ -154,7 +147,7 @@ class AdminHandler(BaseHandler):
"[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
) )
forgotten = yield self.store.did_forget(user_id, room_id) forgotten = await self.store.did_forget(user_id, room_id)
if forgotten: if forgotten:
logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
continue continue
@ -166,7 +159,7 @@ class AdminHandler(BaseHandler):
if room.membership == Membership.INVITE: if room.membership == Membership.INVITE:
event_id = room.event_id event_id = room.event_id
invite = yield self.store.get_event(event_id, allow_none=True) invite = await self.store.get_event(event_id, allow_none=True)
if invite: if invite:
invited_state = invite.unsigned["invite_room_state"] invited_state = invite.unsigned["invite_room_state"]
writer.write_invite(room_id, invite, invited_state) writer.write_invite(room_id, invite, invited_state)
@ -177,7 +170,7 @@ class AdminHandler(BaseHandler):
# were joined. We estimate that point by looking at the # were joined. We estimate that point by looking at the
# stream_ordering of the last membership if it wasn't a join. # stream_ordering of the last membership if it wasn't a join.
if room.membership == Membership.JOIN: if room.membership == Membership.JOIN:
stream_ordering = yield self.store.get_room_max_stream_ordering() stream_ordering = self.store.get_room_max_stream_ordering()
else: else:
stream_ordering = room.stream_ordering stream_ordering = room.stream_ordering
@ -203,7 +196,7 @@ class AdminHandler(BaseHandler):
# events that we have and then filtering, this isn't the most # events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything. # efficient method perhaps but it does guarantee we get everything.
while True: while True:
events, _ = yield self.store.paginate_room_events( events, _ = await self.store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction="f" room_id, from_key, to_key, limit=100, direction="f"
) )
if not events: if not events:
@ -211,7 +204,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after from_key = events[-1].internal_metadata.after
events = yield filter_events_for_client(self.storage, user_id, events) events = await filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events) writer.write_events(room_id, events)
@ -247,7 +240,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities: for event_id in extremities:
if not event_to_unseen_prevs[event_id]: if not event_to_unseen_prevs[event_id]:
continue continue
state = yield self.state_store.get_state_for_event(event_id) state = await self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state) writer.write_state(room_id, event_id, state)
return writer.finished() return writer.finished()

View file

@ -15,8 +15,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled self._account_validity_enabled = hs.config.account_validity.enabled
@defer.inlineCallbacks async def deactivate_account(self, user_id, erase_data, id_server=None):
def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account """Deactivate a user's account
Args: Args:
@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler):
identity_server_supports_unbinding = True identity_server_supports_unbinding = True
# Retrieve the 3PIDs this user has bound to an identity server # Retrieve the 3PIDs this user has bound to an identity server
threepids = yield self.store.user_get_bound_threepids(user_id) threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in threepids: for threepid in threepids:
try: try:
result = yield self._identity_handler.try_unbind_threepid( result = await self._identity_handler.try_unbind_threepid(
user_id, user_id,
{ {
"medium": threepid["medium"], "medium": threepid["medium"],
@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler):
# Do we want this to be a fatal error or should we carry on? # Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server") logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server")
yield self.store.user_delete_threepid( await self.store.user_delete_threepid(
user_id, threepid["medium"], threepid["address"] user_id, threepid["medium"], threepid["address"]
) )
# Remove all 3PIDs this user has bound to the homeserver # Remove all 3PIDs this user has bound to the homeserver
yield self.store.user_delete_threepids(user_id) await self.store.user_delete_threepids(user_id)
# delete any devices belonging to the user, which will also # delete any devices belonging to the user, which will also
# delete corresponding access tokens. # delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id) await self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with # then delete any remaining access tokens which weren't associated with
# a device. # a device.
yield self._auth_handler.delete_access_tokens_for_user(user_id) await self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.user_set_password_hash(user_id, None) await self.store.user_set_password_hash(user_id, None)
# Add the user to a table of users pending deactivation (ie. # Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of) # removal from all the rooms they're a member of)
yield self.store.add_user_pending_deactivation(user_id) await self.store.add_user_pending_deactivation(user_id)
# delete from user directory # delete from user directory
yield self.user_directory_handler.handle_user_deactivated(user_id) await self.user_directory_handler.handle_user_deactivated(user_id)
# Mark the user as erased, if they asked for that # Mark the user as erased, if they asked for that
if erase_data: if erase_data:
logger.info("Marking %s as erased", user_id) logger.info("Marking %s as erased", user_id)
yield self.store.mark_user_erased(user_id) await self.store.mark_user_erased(user_id)
# Now start the process that goes through that list and # Now start the process that goes through that list and
# parts users from rooms (if it isn't already running) # parts users from rooms (if it isn't already running)
@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler):
# Reject all pending invites for the user, so that the user doesn't show up in the # Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list. # "invited" section of rooms' members list.
yield self._reject_pending_invites_for_user(user_id) await self._reject_pending_invites_for_user(user_id)
# Remove all information on the user from the account_validity table. # Remove all information on the user from the account_validity table.
if self._account_validity_enabled: if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id) await self.store.delete_account_validity_for_user(user_id)
# Mark the user as deactivated. # Mark the user as deactivated.
yield self.store.set_user_deactivated_status(user_id, True) await self.store.set_user_deactivated_status(user_id, True)
return identity_server_supports_unbinding return identity_server_supports_unbinding
@defer.inlineCallbacks async def _reject_pending_invites_for_user(self, user_id):
def _reject_pending_invites_for_user(self, user_id):
"""Reject pending invites addressed to a given user ID. """Reject pending invites addressed to a given user ID.
Args: Args:
user_id (str): The user ID to reject pending invites for. user_id (str): The user ID to reject pending invites for.
""" """
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
pending_invites = yield self.store.get_invited_rooms_for_user(user_id) pending_invites = await self.store.get_invited_rooms_for_user(user_id)
for room in pending_invites: for room in pending_invites:
try: try:
yield self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
create_requester(user), create_requester(user),
user, user,
room.room_id, room.room_id,
@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler):
if not self._user_parter_running: if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop) run_as_background_process("user_parter_loop", self._user_parter_loop)
@defer.inlineCallbacks async def _user_parter_loop(self):
def _user_parter_loop(self):
"""Loop that parts deactivated users from rooms """Loop that parts deactivated users from rooms
Returns: Returns:
@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("Starting user parter") logger.info("Starting user parter")
try: try:
while True: while True:
user_id = yield self.store.get_user_pending_deactivation() user_id = await self.store.get_user_pending_deactivation()
if user_id is None: if user_id is None:
break break
logger.info("User parter parting %r", user_id) logger.info("User parter parting %r", user_id)
yield self._part_user(user_id) await self._part_user(user_id)
yield self.store.del_user_pending_deactivation(user_id) await self.store.del_user_pending_deactivation(user_id)
logger.info("User parter finished parting %r", user_id) logger.info("User parter finished parting %r", user_id)
logger.info("User parter finished: stopping") logger.info("User parter finished: stopping")
finally: finally:
self._user_parter_running = False self._user_parter_running = False
@defer.inlineCallbacks async def _part_user(self, user_id):
def _part_user(self, user_id):
"""Causes the given user_id to leave all the rooms they're joined to """Causes the given user_id to leave all the rooms they're joined to
Returns: Returns:
@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler):
""" """
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
rooms_for_user = yield self.store.get_rooms_for_user(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id)
for room_id in rooms_for_user: for room_id in rooms_for_user:
logger.info("User parter parting %r from %r", user_id, room_id) logger.info("User parter parting %r from %r", user_id, room_id)
try: try:
yield self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
create_requester(user), create_requester(user),
user, user,
room_id, room_id,